Full Code of invoke-ai/invoke-training for AI

main 363f83cdb5e6 cached
246 files
930.6 KB
219.8k tokens
581 symbols
1 requests
Download .txt
Showing preview only (1,010K chars total). Download the full file or copy to clipboard to get everything.
Repository: invoke-ai/invoke-training
Branch: main
Commit: 363f83cdb5e6
Files: 246
Total size: 930.6 KB

Directory structure:
gitextract_5kosyax7/

├── .github/
│   └── workflows/
│       ├── deploy.yaml
│       └── test.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── docs/
│   ├── contributing/
│   │   ├── development_environment.md
│   │   ├── directory_structure.md
│   │   ├── documentation.md
│   │   └── tests.md
│   ├── get-started/
│   │   ├── installation.md
│   │   └── quick-start.md
│   ├── guides/
│   │   ├── dataset_formats.md
│   │   ├── model_merge.md
│   │   └── stable_diffusion/
│   │       ├── dpo_lora_sd.md
│   │       ├── gnome_lora_masks_sdxl.md
│   │       ├── robocats_finetune_sdxl.md
│   │       └── textual_inversion_sdxl.md
│   ├── index.md
│   ├── reference/
│   │   └── config/
│   │       ├── index.md
│   │       ├── pipelines/
│   │       │   ├── sd_lora.md
│   │       │   ├── sd_textual_inversion.md
│   │       │   ├── sdxl_finetune.md
│   │       │   ├── sdxl_lora.md
│   │       │   ├── sdxl_lora_and_textual_inversion.md
│   │       │   └── sdxl_textual_inversion.md
│   │       └── shared/
│   │           ├── data/
│   │           │   ├── data_loader_config.md
│   │           │   └── dataset_config.md
│   │           └── optimizer_config.md
│   └── templates/
│       └── python/
│           └── material/
│               └── labels.html
├── mkdocs.yml
├── pyproject.toml
├── sample_data/
│   └── bruce_the_gnome/
│       └── data.jsonl
├── src/
│   └── invoke_training/
│       ├── __init__.py
│       ├── _shared/
│       │   ├── __init__.py
│       │   ├── accelerator/
│       │   │   ├── __init__.py
│       │   │   └── accelerator_utils.py
│       │   ├── checkpoints/
│       │   │   ├── __init__.py
│       │   │   ├── checkpoint_tracker.py
│       │   │   ├── lora_checkpoint_utils.py
│       │   │   └── serialization.py
│       │   ├── data/
│       │   │   ├── ARCHITECTURE.md
│       │   │   ├── __init__.py
│       │   │   ├── data_loaders/
│       │   │   │   ├── __init__.py
│       │   │   │   ├── dreambooth_sd_dataloader.py
│       │   │   │   ├── image_caption_flux_dataloader.py
│       │   │   │   ├── image_caption_sd_dataloader.py
│       │   │   │   ├── image_pair_preference_sd_dataloader.py
│       │   │   │   └── textual_inversion_sd_dataloader.py
│       │   │   ├── datasets/
│       │   │   │   ├── __init__.py
│       │   │   │   ├── build_dataset.py
│       │   │   │   ├── hf_image_caption_dataset.py
│       │   │   │   ├── hf_image_pair_preference_dataset.py
│       │   │   │   ├── image_caption_dir_dataset.py
│       │   │   │   ├── image_caption_jsonl_dataset.py
│       │   │   │   ├── image_dir_dataset.py
│       │   │   │   ├── image_pair_preference_dataset.py
│       │   │   │   └── transform_dataset.py
│       │   │   ├── samplers/
│       │   │   │   ├── __init__.py
│       │   │   │   ├── aspect_ratio_bucket_batch_sampler.py
│       │   │   │   ├── batch_offset_sampler.py
│       │   │   │   ├── concat_sampler.py
│       │   │   │   ├── interleaved_sampler.py
│       │   │   │   └── offset_sampler.py
│       │   │   ├── transforms/
│       │   │   │   ├── __init__.py
│       │   │   │   ├── caption_prefix_transform.py
│       │   │   │   ├── concat_fields_transform.py
│       │   │   │   ├── constant_field_transform.py
│       │   │   │   ├── drop_field_transform.py
│       │   │   │   ├── flux_image_transform.py
│       │   │   │   ├── load_cache_transform.py
│       │   │   │   ├── sd_image_transform.py
│       │   │   │   ├── shuffle_caption_transform.py
│       │   │   │   ├── template_caption_transform.py
│       │   │   │   └── tensor_disk_cache.py
│       │   │   └── utils/
│       │   │       ├── __init__.py
│       │   │       ├── aspect_ratio_bucket_manager.py
│       │   │       ├── resize.py
│       │   │       └── resolution.py
│       │   ├── flux/
│       │   │   ├── encoding_utils.py
│       │   │   ├── lora_checkpoint_utils.py
│       │   │   ├── model_loading_utils.py
│       │   │   └── validation.py
│       │   ├── optimizer/
│       │   │   ├── __init__.py
│       │   │   └── optimizer_utils.py
│       │   ├── stable_diffusion/
│       │   │   ├── __init__.py
│       │   │   ├── base_model_version.py
│       │   │   ├── checkpoint_utils.py
│       │   │   ├── lora_checkpoint_utils.py
│       │   │   ├── min_snr_weighting.py
│       │   │   ├── model_loading_utils.py
│       │   │   ├── textual_inversion.py
│       │   │   ├── tokenize_captions.py
│       │   │   └── validation.py
│       │   ├── tools/
│       │   │   ├── __init__.py
│       │   │   └── generate_images.py
│       │   └── utils/
│       │       ├── import_xformers.py
│       │       └── jsonl.py
│       ├── config/
│       │   ├── __init__.py
│       │   ├── base_pipeline_config.py
│       │   ├── config_base_model.py
│       │   ├── data/
│       │   │   ├── __init__.py
│       │   │   ├── data_loader_config.py
│       │   │   └── dataset_config.py
│       │   ├── optimizer/
│       │   │   ├── __init__.py
│       │   │   └── optimizer_config.py
│       │   └── pipeline_config.py
│       ├── model_merge/
│       │   ├── __init__.py
│       │   ├── extract_lora.py
│       │   ├── merge_models.py
│       │   ├── merge_tasks_to_base.py
│       │   ├── scripts/
│       │   │   ├── extract_lora_from_model_diff.py
│       │   │   ├── merge_lora_into_model.py
│       │   │   ├── merge_models.py
│       │   │   └── merge_task_models_to_base_model.py
│       │   └── utils/
│       │       ├── __init__.py
│       │       ├── normalize_weights.py
│       │       └── parse_model_arg.py
│       ├── pipelines/
│       │   ├── __init__.py
│       │   ├── _experimental/
│       │   │   └── sd_dpo_lora/
│       │   │       ├── config.py
│       │   │       └── train.py
│       │   ├── callbacks.py
│       │   ├── flux/
│       │   │   └── lora/
│       │   │       ├── __init__.py
│       │   │       ├── config.py
│       │   │       └── train.py
│       │   ├── invoke_train.py
│       │   ├── stable_diffusion/
│       │   │   ├── __init__.py
│       │   │   ├── lora/
│       │   │   │   ├── __init__.py
│       │   │   │   ├── config.py
│       │   │   │   └── train.py
│       │   │   └── textual_inversion/
│       │   │       ├── __init__.py
│       │   │       ├── config.py
│       │   │       └── train.py
│       │   └── stable_diffusion_xl/
│       │       ├── __init__.py
│       │       ├── finetune/
│       │       │   ├── __init__.py
│       │       │   ├── config.py
│       │       │   └── train.py
│       │       ├── lora/
│       │       │   ├── __init__.py
│       │       │   ├── config.py
│       │       │   └── train.py
│       │       ├── lora_and_textual_inversion/
│       │       │   ├── __init__.py
│       │       │   ├── config.py
│       │       │   └── train.py
│       │       └── textual_inversion/
│       │           ├── __init__.py
│       │           ├── config.py
│       │           └── train.py
│       ├── sample_configs/
│       │   ├── _experimental/
│       │   │   ├── sd_dpo_lora_pickapic_1x24gb.yaml
│       │   │   └── sd_dpo_lora_refinement_pokemon_1x24gb.yaml
│       │   ├── flux_lora_1x40gb.yaml
│       │   ├── sd_lora_baroque_1x8gb.yaml
│       │   ├── sd_textual_inversion_gnome_1x8gb.yaml
│       │   ├── sdxl_finetune_baroque_1x24gb.yaml
│       │   ├── sdxl_finetune_robocats_1x24gb.yaml
│       │   ├── sdxl_lora_and_ti_gnome_1x24gb.yaml
│       │   ├── sdxl_lora_baroque_1x24gb.yaml
│       │   ├── sdxl_lora_baroque_1x8gb.yaml
│       │   ├── sdxl_lora_masks_gnome_1x24gb.yaml
│       │   ├── sdxl_textual_inversion_gnome_1x24gb.yaml
│       │   └── sdxl_textual_inversion_masks_gnome_1x24gb.yaml
│       ├── scripts/
│       │   ├── __init__.py
│       │   ├── _experimental/
│       │   │   ├── auto_caption/
│       │   │   │   └── auto_caption_images.py
│       │   │   ├── masks/
│       │   │   │   ├── clipseg.py
│       │   │   │   ├── generate_masks.py
│       │   │   │   └── generate_masks_for_jsonl_dataset.py
│       │   │   └── rank_images.py
│       │   ├── convert_sd_lora_to_kohya_format.py
│       │   ├── invoke_generate_images.py
│       │   ├── invoke_train.py
│       │   ├── invoke_train_ui.py
│       │   ├── invoke_visualize_data_loading.py
│       │   └── utils/
│       │       └── image_dir_dataset.py
│       └── ui/
│           ├── __init__.py
│           ├── app.py
│           ├── config_groups/
│           │   ├── __init__.py
│           │   ├── aspect_ratio_bucket_config_group.py
│           │   ├── base_pipeline_config_group.py
│           │   ├── dataset_config_group.py
│           │   ├── flux_lora_config_group.py
│           │   ├── image_caption_sd_data_loader_config_group.py
│           │   ├── optimizer_config_group.py
│           │   ├── sd_lora_config_group.py
│           │   ├── sd_textual_inversion_config_group.py
│           │   ├── sdxl_finetune_config_group.py
│           │   ├── sdxl_lora_and_textual_inversion_config_group.py
│           │   ├── sdxl_lora_config_group.py
│           │   ├── sdxl_textual_inversion_config_group.py
│           │   ├── textual_inversion_sd_data_loader_config_group.py
│           │   └── ui_config_element.py
│           ├── gradio_blocks/
│           │   ├── header.py
│           │   └── pipeline_tab.py
│           ├── index.html
│           ├── pages/
│           │   ├── data_page.py
│           │   └── training_page.py
│           └── utils/
│               ├── prompts.py
│               └── utils.py
└── tests/
    └── invoke_training/
        ├── _shared/
        │   ├── __init__.py
        │   ├── checkpoints/
        │   │   ├── test_checkpoint_tracker.py
        │   │   └── test_serialization.py
        │   ├── data/
        │   │   ├── __init__.py
        │   │   ├── data_loaders/
        │   │   │   ├── __init__.py
        │   │   │   ├── test_dreambooth_sd_dataloader.py
        │   │   │   ├── test_image_caption_sd_dataloader.py
        │   │   │   ├── test_image_pair_preference_sd_dataloader.py
        │   │   │   └── test_textual_inversion_sd_dataloader.py
        │   │   ├── dataset_fixtures.py
        │   │   ├── datasets/
        │   │   │   ├── __init__.py
        │   │   │   ├── test_hf_image_caption_dataset.py
        │   │   │   ├── test_hf_image_pair_preference_dataset.py
        │   │   │   ├── test_image_caption_dir_dataset.py
        │   │   │   ├── test_image_caption_jsonl_dataset.py
        │   │   │   ├── test_image_dir_dataset.py
        │   │   │   ├── test_image_pair_preference_dataset.py
        │   │   │   └── test_transform_dataset.py
        │   │   ├── samplers/
        │   │   │   ├── __init__.py
        │   │   │   ├── test_aspect_ratio_bucket_batch_sampler.py
        │   │   │   ├── test_batch_offset_sampler.py
        │   │   │   ├── test_concat_sampler.py
        │   │   │   ├── test_interleaved_sampler.py
        │   │   │   └── test_offset_sampler.py
        │   │   ├── transforms/
        │   │   │   ├── __init__.py
        │   │   │   ├── test_caption_prefix_transform.py
        │   │   │   ├── test_concat_fields_transform.py
        │   │   │   ├── test_constant_field_transform.py
        │   │   │   ├── test_drop_field_transform.py
        │   │   │   ├── test_load_cache_transform.py
        │   │   │   ├── test_sd_image_transform.py
        │   │   │   ├── test_shuffle_caption_transform.py
        │   │   │   ├── test_template_caption_transform.py
        │   │   │   └── test_tensor_disk_cache.py
        │   │   └── utils/
        │   │       ├── __init__.py
        │   │       ├── test_aspect_ratio_bucket_manager.py
        │   │       ├── test_resize.py
        │   │       └── test_resolution.py
        │   ├── stable_diffusion/
        │   │   ├── __init__.py
        │   │   ├── test_base_model_version.py
        │   │   ├── test_lora_checkpoint_utils.py
        │   │   ├── test_model_loading_utils.py
        │   │   ├── test_textual_inversion.py
        │   │   └── ti_embedding_checkpoint_fixture.py
        │   └── utils/
        │       └── test_jsonl.py
        ├── config/
        │   └── pipelines/
        │       └── test_pipeline_config.py
        ├── model_merge/
        │   ├── __init__.py
        │   ├── test_merge_models.py
        │   ├── test_merge_tasks_to_base.py
        │   └── utils.py
        └── ui/
            └── utils/
                └── test_prompts.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/deploy.yaml
================================================
name: Deploy invoke-training docs

on:
  push:
    branches:
      - main

permissions:
  contents: write

jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Configure Git Credentials
        run: |
          git config user.name github-actions[bot]
          git config user.email 41898282+github-actions[bot]@users.noreply.github.com
      - uses: actions/setup-python@v4
        with:
          python-version: "3.10"
          cache: pip
          cache-dependency-path: pyproject.toml
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install .[test]
      - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 
      - uses: actions/cache@v3
        with:
          key: mkdocs-material-${{ env.cache_id }}
          path: .cache
          restore-keys: |
            mkdocs-material-
      - run: mkdocs gh-deploy --force


================================================
FILE: .github/workflows/test.yaml
================================================
name: Test invoke-training

on:
  push:
    branches:
      - main
  pull_request:
  workflow_dispatch:

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.12"]

    steps:
    - uses: actions/checkout@v4
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v5
      with:
        python-version: ${{ matrix.python-version }}
        cache: pip
        cache-dependency-path: pyproject.toml
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip install .[test]
    - name: Ruff lint
      run: |
        ruff check --output-format=github .
    - name: Ruff format
      run: |
        ruff format --check .
    - name: Test with pytest
      run: |
        pytest tests --junitxml=junit/test-results-${{ matrix.python-version }}.xml -m "not cuda and not loads_model"
    - name: Upload pytest test results
      uses: actions/upload-artifact@v4
      with:
        name: pytest-results-${{ matrix.python-version }}
        path: junit/test-results-${{ matrix.python-version }}.xml
      # Use always() to always run this step to publish test results when there are test failures.
      if: ${{ always() }}


================================================
FILE: .gitignore
================================================
/output/
/test_configs/
/data/

# pyenv
.python-version

# VSCode
.vscode/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
junit/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.aider*


================================================
FILE: .pre-commit-config.yaml
================================================
# See https://pre-commit.com/ for usage and config.
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
  # Ruff version.
  rev: v0.1.7
  hooks:
    # Run the linter.
    - id: ruff
    # Run the formatter.
    - id: ruff-format


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# invoke-training

A library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion, etc.) that can be used in [InvokeAI](https://github.com/invoke-ai/InvokeAI).

> [!WARNING]
> `invoke-training` is still under active development, and breaking changes are likely. Full backwards compatibility will not be guaranteed until v1.0.0.
> In the meantime, I recommend pinning to a specific commit hash.

## Documentation

<https://invoke-ai.github.io/invoke-training/>

## Training Modes

- Stable Diffusion
  - LoRA
  - DreamBooth LoRA
  - Textual Inversion
- Stable Diffusion XL
  - Full finetuning
  - LoRA
  - DreamBooth LoRA
  - Textual Inversion
  - LoRA and Textual Inversion

More training modes coming soon!

## Installation

See the [Installation](https://invoke-ai.github.io/invoke-training/get-started/installation/) section of the documentation.

## Quick Start

`invoke-training` pipelines can be configured and launched from either the CLI or the GUI.

### CLI

Run training via the CLI with type-checked YAML configuration files for maximum control:

```bash
invoke-train --cfg-file src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml
```

### GUI

Run training via the GUI for a simpler starting point.

```bash
invoke-train-ui

# Or, you can optionally override the default host and port:
invoke-train-ui --host 0.0.0.0 --port 1234
```

## Features

Training progress can be monitored with [Tensorboard](https://www.tensorflow.org/tensorboard):
![Screenshot of the Tensorboard UI showing validation images.](docs/images/tensorboard_val_images_screenshot.png)
_Validation images in the Tensorboard UI._

All trained models are compatible with InvokeAI:

![Screenshot of the InvokeAI UI with an example of a Yoda pokemon generated using a Pokemon LoRA model.](docs/images/invokeai_yoda_pokemon_lora.png)
_Example image generated with the prompt "A cute yoda pokemon creature." and a trained Pokemon LoRA._

## Contributing

Contributors are welcome. For developer guidance, see the [Contributing](https://invoke-ai.github.io/invoke-training/contributing/development_environment/) section of the documentation.


================================================
FILE: docs/contributing/development_environment.md
================================================
# Development Environment Setup

See the [developer installation instructions](../get-started/installation.md#developer-installation).


================================================
FILE: docs/contributing/directory_structure.md
================================================
# Directory Structure

```bash
invoke-training/
├── README.md
├── docs/
├── src/
│   └── invoke_training/
│       ├── _shared/ # Utilities shared across multiple pipelines. High unit test coverage.
│       ├── config/ # Config structures shared by multiple pipelines.
│       ├── pipelines/ # Each pipeline is isolated in its own directory with a train.py and config.py.
│       │   ├── stable_diffusion/
│       │   │   ├── lora/
│       │   │   │   ├── config.py
│       │   │   │   └── train.py
│       │   │   └── textual_inversion/
│       │   │       └── ...
│       │   ├── stable_diffusion_xl/
│       │   └── ...
│       └── scripts/ # Main entrypoints.
└── tests/ # Mirrors src/ directory.
```


================================================
FILE: docs/contributing/documentation.md
================================================
# Documentation

The documentation site is generated using [mkdocs](https://www.mkdocs.org/) and [mkdocstrings-python](https://mkdocstrings.github.io/python/).

To view your documentation changes locally, run `mkdocs serve`.


================================================
FILE: docs/contributing/tests.md
================================================
# Tests

Run all unit tests with:

```bash
pytest tests/
```

There are some test 'markers' defined in [pyproject.toml](https://github.com/invoke-ai/invoke-training/blob/main/pyproject.toml) that can be used to skip some tests. For example, the following command skips tests that require a GPU or require downloading model weights:

```bash
pytest tests/ -m "not cuda and not loads_model"
```


================================================
FILE: docs/get-started/installation.md
================================================
# Installation

## Requirements

1. Python 3.10, 3.11 and 3.12 are currently supported. Check your Python version by running `python -V`.
2. An NVIDIA GPU with >= 8 GB VRAM is recommended for model training.

## Basic Installation

0. Open your terminal and navigate to the directory where you want to clone the `invoke-training` repo.
1. Clone the repo:

   ```bash
   git clone https://github.com/invoke-ai/invoke-training.git
   ```

2. Create and activate a python [virtual environment](https://docs.python.org/3/library/venv.html#creating-virtual-environments). This creates an isolated environment for `invoke-training` and its dependencies that won't interfere with other python environments on your system, including any installations of [InvokeAI](https://www.github.com/invoke-ai/invokeai).

   ```bash
   # Navigate to the invoke-training directory.
   cd invoke-training

   # Create a new virtual environment named `invoketraining`.
   python -m venv invoketraining

   # Activate the new virtual environment.
   # On Windows:
   .\invoketraining\Scripts\activate
   # On MacOS / Linux:
   source invoketraining/bin/activate
   ```

3. Install `invoke-training` and its dependencies. Run the appropriate install command for your system.

   ```bash
   # A recent version of pip is required, so first upgrade pip:
   python -m pip install --upgrade pip

   # Install - Windows or Linux with an NVIDIA GPU:
   pip install ".[test]" --extra-index-url https://download.pytorch.org/whl/cu126

   # Install - Linux with no GPU:
   pip install ".[test]" --extra-index-url https://download.pytorch.org/whl/cpu

   # Install - All other systems:
   pip install ".[test]"
   ```

In the future, before you run `invoke-training`, you must activate the virtual environment you created during installation, using the same command you used during installation.

## Developer Installation

Consider forking the repo if you plan to contribute code changes.

Follow the above installation instructions, cloning your fork instead of this repo if you made a fork.

Next, we suggest setting up the repo's pre-commit hooks to automatically format and lint your contributions:

1. (_Optional_) Install the pre-commit hooks: `pre-commit install`. This will run static analysis tools (ruff) on `git commit`.
2. (_Optional_) Setup `ruff` in your IDE of choice.


================================================
FILE: docs/get-started/quick-start.md
================================================
# Quick Start

`invoke-training` has both a GUI and a CLI (for advanced users). The instructions for getting started with both options can be found on this page.

There is also a video introduction to `invoke-training`:

<iframe width="560" height="315" src="https://www.youtube.com/embed/OZIz2vvtlM4?si=iR73F0IhlsolyYAl" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" referrerpolicy="strict-origin-when-cross-origin" allowfullscreen></iframe>

## Quick Start - GUI

### 1. Installation

Follow the [`invoke-training` installation instructions](./installation.md).

### 2. Launch the GUI

Activate the virtual environment you created during installation, using the same command you used during installation.

You'll need to do this every time you run `invoke-training`.

```bash
# From the invoke-training directory:
invoke-train-ui

# Or, you can optionally override the default host and port:
invoke-train-ui --host 0.0.0.0 --port 1234
```

Access the GUI in your browser at the URL printed to the console.

### 3. Configure the training job

Select the desired training pipeline type in the top-level tab.

For this tutorial, we don't need to change any of the configuration values. The preset configuration should work well.

### 4. Generate the YAML configuration

Click on 'Generate Config' to generate a YAML configuration file. This YAML configuration file could be used to launch the training job from the CLI, if desired.

### 5. Start training

Click 'Start Training' and check your terminal for progress logs.

### 6. Monitor training

Monitor the training process with Tensorboard by running `tensorboard --logdir output/` and visiting [localhost:6006](http://localhost:6006) in your browser. Here you can see generated validation images throughout the training process.

![Screenshot of the Tensorboard UI showing validation images.](../images/tensorboard_val_images_screenshot.png)
_Validation images in the Tensorboard UI._

### 7. Import into InvokeAI

Select a checkpoint based on the quality of the generated images.

If you haven't already, setup [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.

Copy your selected LoRA checkpoint into your `${INVOKEAI_ROOT}/autoimport/lora` directory. For example:

```bash
# Note: You will have to replace the timestamp in the checkpoint path.
cp output/1691088769.5694647/checkpoint_epoch-00000002.safetensors ${INVOKEAI_ROOT}/autoimport/lora/pokemon_epoch-00000002.safetensors
```

You can now use your trained Pokemon LoRA in the InvokeAI UI! 🎉

![Screenshot of the InvokeAI UI with an example of a Yoda pokemon generated using a Pokemon LoRA model.](../images/invokeai_yoda_pokemon_lora.png)
_Example image generated with the prompt "A cute yoda pokemon creature." and Pokemon LoRA._

## Quick Start - CLI

### 1. Installation

Follow the [`invoke-training` installation instructions](./installation.md).

### 2. Training

Activate the virtual environment you created during installation, using the same command you used during installation.

You'll need to do this every time you run `invoke-training`.

See the [Textual Inversion - SDXL](../guides/stable_diffusion/textual_inversion_sdxl.md) tutorial for instructions on how to train a model via the CLI.


================================================
FILE: docs/guides/dataset_formats.md
================================================
# Dataset Formats

`invoke-training` supports the following dataset formats:

- `IMAGE_CAPTION_JSONL_DATASET`: A local image-caption dataset described by a single `.jsonl` file.
- `IMAGE_CAPTION_DIR_DATASET`: A local directory of images with associated `.txt` caption files.
- `IMAGE_DIR_DATASET`: A local directory of images (without captions).
- `HF_HUB_IMAGE_CAPTION_DATASET`: A Hugging Face Hub dataset containing images and captions.

See the documentation for a particular training pipeline to see which dataset formats it supports.

The following sections explain each of these formats in more detail.

## `IMAGE_CAPTION_JSONL_DATASET`

Config documentation: [ImageCaptionJsonlDatasetConfig][invoke_training.config.data.dataset_config.ImageCaptionJsonlDatasetConfig]

An `IMAGE_CAPTION_JSONL_DATASET` consists of a single `.jsonl` file containing image paths and associated captions.

Sample directory structure:
```bash
my_custom_dataset/
├── data.jsonl
└── train/
    ├── 0001.png
    ├── 0002.png
    ├── 0003.png
    └── ...
```

The contents of `data.jsonl` would be:
```json
{"file_name": "train/0001.png", "text": "This is a caption describing image 0001."}
{"file_name": "train/0002.png", "text": "This is a caption describing image 0002."}
{"file_name": "train/0003.png", "text": "This is a caption describing image 0003."}
```

The image file paths can be either absolute paths, or relative to the `.jsonl` file.

Finally, this dataset can be used with the following pipeline dataset configuration:
```yaml
type: IMAGE_CAPTION_JSONL_DATASET
jsonl_path: /path/to/my_custom_dataset/data.jsonl
image_column: file_name
caption_column: text
```

A useful characteristic of this dataset format is that a `.jsonl` file can reference an image file anywhere on the local disk. It is common to maintain multiple `.jsonl` datasets that reference some of the same images without needing multiple copies of those images on disk.

## `IMAGE_CAPTION_DIR_DATASET`

Config documentation: [ImageCaptionDirDatasetConfig][invoke_training.config.data.dataset_config.ImageCaptionDirDatasetConfig]

An `IMAGE_CAPTION_DIR_DATASET` consists of a directory of image files and corresponding `.txt` caption files of the same name.

Sample directory structure:
```bash
my_custom_dataset/
├── 0001.png
├── 0001.txt
├── 0002.jpg
├── 0002.txt
├── 0003.png
├── 0003.txt
└── ...
```

Each `.txt` file should contain a caption on the first line of the file. Here are the sample contents of `0001.txt`:
```txt title="0001.txt"
this is a caption for example 0001
```

This dataset can be used with the following pipeline dataset configuration:
```yaml
type: IMAGE_CAPTION_DIR_DATASET
dataset_dir: /path/to/my_custom_dataset
```

## `IMAGE_DIR_DATASET`

Config documentation: [ImageDirDatasetConfig][invoke_training.config.data.dataset_config.ImageDirDatasetConfig]

An `IMAGE_DIR_DATASET` consists of a single directory of images (without captions).

Sample directory structure:
```bash
my_custom_dataset/
├── 0001.png
├── 0002.jpg
├── 0003.png
└── ...
```

This dataset can be used with the following pipeline dataset configuration:
```yaml
type: IMAGE_DIR_DATASET
dataset_dir: /path/to/my_custom_dataset
```

## `HF_HUB_IMAGE_CAPTION_DATASET`

Config documentation: [HFHubImageCaptionDatasetConfig][invoke_training.config.data.dataset_config.HFHubImageCaptionDatasetConfig]

The `HF_HUB_IMAGE_CAPTION_DATASET` dataset format can be used to access public datasets on the [Hugging Face Hub](https://huggingface.co/datasets). You can filter for the `Text-to-Image` task to find relevant datasets that contain both an image column and a caption column. [lambdalabs/pokemon-blip-captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) is a popular choice if you're not sure where to start.


================================================
FILE: docs/guides/model_merge.md
================================================
# Model Merging

`invoke-training` provides utility scripts for several common model merging workflows. This page contains a summary of the available tools.

## `extract_lora_from_model_diff.py`

Extract a LoRA model that represents the difference between two base models.

Note that the extracted LoRA model is a lossy representation of the difference between the models, so some degradation in quality is expected.

For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py -h
```

## `merge_lora_into_model.py`

Merge a LoRA model into a base model to produce a new base model.

For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/merge_lora_into_model.py -h
```

## `merge_models.py`

Merge 2 or more base models to produce a single base model (using either LERP or SLERP). This is a simple merge strategy that merges all model weights in the same way.

For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/merge_models.py -h
```

## `merge_task_models_to_base_model.py`

Merge 1 or more task-specific base models into a single starting base model (using either [TIES](https://arxiv.org/abs/2306.01708) or [DARE](https://arxiv.org/abs/2311.03099)). This merge strategy aims to preserve the task-specific behaviors of the task models while making only small changes to the original base model. This approach enables multiple task models to be merged without excessive interference between them.

If you want to merge a task-specific LoRA into a base model using this strategy, first use `merge_lora_into_model.py` to produce a task-specific base model, then merge that new base model using this strategy.

For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/merge_task_models_to_base_model.py -h
```


================================================
FILE: docs/guides/stable_diffusion/dpo_lora_sd.md
================================================
# (Experimental) Diffusion DPO - SD

!!! tip "Experimental"
    The Diffusion Direct Preference Optimization training pipeline is still experimental. Support may be dropped at any time.

This tutorial walks through some initial experiments around using Diffusion Direct Preference Optimization (DPO) ([paper](https://arxiv.org/abs/2311.12908)) to train Stable Diffusion LoRA models.


## Experiment 1: `pickapic_v2` LoRA Training

The Diffusion-DPO paper does full model fine-tuning on the [pickapic_v2](https://huggingface.co/datasets/yuvalkirstain/pickapic_v2) dataset, which consists of roughly 1M AI-generated image pairs with preference annotations. In this experiment, we attempt to fine-tune a Stable Diffusion LoRA model using a small subset of the pickapic_v2 dataset.

Run this experiment with the following command:
```bash
invoke-train -c src/invoke_training/sample_configs/_experimental/sd_dpo_lora_pickapic_1x24gb.yaml
```

Here is a cherry-picked example of a prompt for which this training process was clearly beneficial.
Prompt: "*A galaxy-colored figurine is floating over the sea at sunset, photorealistic*"

| Before DPO Training | After DPO Training (same seed)|
| - | - |
| ![Sample image before DPO training.](../../images/dpo/before_dpo.jpg) | ![Sample image after DPO training.](../../images/dpo/after_dpo.jpg) |

## Experiment 2: LoRA Model Refinement

As a second experiment, we attempt the following workflow:

1. Train a Stable Diffusion LoRA model on a particular style.
2. Generate pairs of images using the trained LoRA model.
3. Annotate the preferred image from each pair.
4. Apply Diffusion-DPO to the preference-annotated pairs to further fine-tune the LoRA model.

Note: The steps listed below are pretty rough. They are included primarily for reference for someone looking to resume this line of work in the future.

### 1. Train a style LoRA

```bash
invoke-train -c src/invoke_training/sample_configs/sd_lora_pokemon_1x8gb.yaml
```

### 2. Generate images

Prepare ~100 relevant prompts that will be used to generate training data with the freshly-trained LoRA model. Add the prompts to a `.txt` file - one prompt per line.

Example prompts:
```txt
a cute orange pokemon character with pointy ears
a drawing of a purple fish
a cartoon blob with a smile on its face
a drawing of a snail with big eyes
...
```

```bash
# Convert the LoRA checkpoint of interest to Kohya format.
# You will have to change the path timestamps in this example command.
# TODO(ryand): This manual conversion shouldn't be necessary.
python src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py \
  --src-ckpt-dir output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003/ \
  --dst-ckpt-file output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003_kohya.safetensors

# Generate 2 pairs of images for each prompt.
invoke-generate-images \
  -o output/pokemon_pairs \
  -m runwayml/stable-diffusion-v1-5 \
  -v fp16 \
  -l output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003_kohya.safetensors \
  --sd-version SD \
  --prompt-file path/to/prompts.txt \
  --set-size 2 \
  --num-sets 2 \
  --height 512 \
  --width 512
```

### 3. Annotate the image pair preferences

Launch the gradio UI for selecting image pair preferences.

```bash
# Note: rank_images.py accepts a full training pipeline config, but only uses the dataset configuration.
python src/invoke_training/scripts/_experimental/rank_images.py -c src/invoke_training/sample_configs/_experimental/sd_dpo_lora_refinement_pokemon_1x24gb.yaml
```

After completing the pair annotations, click "Save Metadata" and move the resultant metadata file to your image data directory (e.g. `output/pokemon_pairs/metadata.jsonl`).

### 4. Run Diffusion-DPO

```bash
invoke-train -c src/invoke_training/sample_configs/_experimental/sd_dpo_lora_refinement_pokemon_1x24gb.yaml
```

================================================
FILE: docs/guides/stable_diffusion/gnome_lora_masks_sdxl.md
================================================
# LoRA with Masks - SDXL

This tutorial explains how to prepare masks for an image dataset and then use that dataset to train an SDXL LoRA model.

Masks can be used to weight regions of images in a dataset to control how much they contribute to the training process. In this tutorial we will use masks to train on a small dataset of images of Bruce the Gnome (4 images). With such a small dataset, there is a high risk of overfitting to the background elements from the images. We will use masks to avoid this problem and focus only on the object of interest.

## 1 - Dataset Preparation

For this tutorial, we'll use a dataset consisting of 4 images of Bruce the Gnome:

| | |
| - | - |
| ![bruce_the_gnome dataset image 1.](../../images/bruce_the_gnome/001.jpg) | ![bruce_the_gnome dataset image 2.](../../images/bruce_the_gnome/002.jpg) |
| ![bruce_the_gnome dataset image 3.](../../images/bruce_the_gnome/003.jpg) | ![bruce_the_gnome dataset image 4.](../../images/bruce_the_gnome/004.jpg) |

This sample dataset is included in the invoke-training repo under [sample_data/bruce_the_gnome](https://github.com/invoke-ai/invoke-training/tree/main/sample_data/bruce_the_gnome).

## 2 - Generate Masks

Use the `generate_masks_for_jsonl_dataset.py` script to generate masks for your dataset based on a single prompt. In this case we are using the prompt `"a stuffed gnome"`:
```bash
python src/invoke_training/scripts/_experimental/masks/generate_masks_for_jsonl_dataset.py \
  --in-jsonl sample_data/bruce_the_gnome/data.jsonl \
  --out-jsonl sample_data/bruce_the_gnome/data_masks.jsonl \
  --prompt "a stuffed gnome"
```

The mask generation script will produce the following outputs:

- A directory of generated masks: `sample_data/bruce_the_gnome/masks/`
- A new `.jsonl` file that references the mask images: `sample_data/bruce_the_gnome/data_masks.jsonl`

## 3 - Review the Generated Masks

Review the generated masks to make sure that the target regions were masked. You may need to adjust the prompt and re-generate the masks to achieve the desired result. Alternatively, you can edit the masks manually. The masks are simply single-channel grayscale images (0=background, 255=foreground).

Here are some examples of the masks that we just generated:

| | |
| - | - |
| ![bruce_the_gnome dataset image 1.](../../images/bruce_the_gnome/001.jpg) | ![bruce_the_gnome dataset image 1 mask.](../../images/bruce_masks/001_mask.png) |
| ![bruce_the_gnome dataset image 2.](../../images/bruce_the_gnome/002.jpg) | ![bruce_the_gnome dataset image 2 mask.](../../images/bruce_masks/002_mask.png) |

## 4 - Configuration

Below is the training configuration that we'll use for this tutorial.

Raw config file: [src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml).


```yaml title="sdxl_lora_masks_gnome_1x24gb.yaml"
--8<-- "src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml"
```

Full documentation of all of the configuration options is here: [LoRA SDXL Config](../../reference/config/pipelines/sdxl_lora.md)

There are a few things to note about this training config:

- We set `use_masks: True` in order to use the masks that we generated. This configuration is only compatible with datasets that have mask data.
- The `learning_rate`, `max_train_steps`, `save_every_n_steps`, and `validate_every_n_steps` are all _lower_ than typical for an SDXL LoRA training pipeline. The combination of masking with the small dataset size causes training to progress very quickly. These configuration fields were all adjusted accordingly to avoid overfitting.

## 5 - Start Training

Launch the training run.
```bash
# From inside the invoke-training/ source directory:
invoke-train -c src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml
```

Training takes ~30 mins on an NVIDIA RTX 4090.

## 6 - Monitor

In a new terminal, launch Tensorboard to monitor the training run:
```bash
tensorboard --logdir output/
```
Access Tensorboard at [localhost:6006](http://localhost:6006) in your browser.

Sample images will be logged to Tensorboard so that you can see how the model is evolving.

Once training is complete, select the model checkpoint that produces the best visual results. For this tutorial, we'll use the checkpoint from step 300:

![Screenshot of the Tensorboard UI showing the validation images for step 300.](../../images/bruce_masks/bruce_masks_step_300.jpg)
*Screenshot of the Tensorboard UI showing the validation images for step 300. The validation prompt was: "A stuffed gnome at the beach with a pina colada in its hand."*


## 7 - Import into InvokeAI

If you haven't already, setup [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.

Import your trained LoRA model from the 'Models' tab.

Congratulations, you can now use your new Bruce-the-Gnome model! 🎉


================================================
FILE: docs/guides/stable_diffusion/robocats_finetune_sdxl.md
================================================
# Finetune - SDXL

This tutorial explains how to do a full finetune training run on a [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) base model.

## 0 - Prerequisites

Full model finetuning is more compute-intensive than parameter-efficient finetuning alternatives (e.g. LoRA or Textual Inversion). This tutorial requires a minimum of 24GB of GPU VRAM.

## 1 - Dataset Preparation

For this tutorial, we will use a dataset consisting of 14 images of robocats. The images were auto-captioned. Here are some sample images from the dataset, including their captions:

| | |
| - | - |
| ![A white robot with blue eyes and a yellow nose sits on a rock, gazing at the camera, with a pink tree and a white cat in the background.](../../images/robocats/sipu3h70yb87rju8a8l36ejr.jpg) | ![A white cat with green eyes and a blue collar sits on a moss-covered rock in a forest, gazing directly at the camera.](../../images/robocats/v2h3ld50bi9owhhzo9gf9utg.jpg) |
| *A white robot with blue eyes and a yellow nose sits on a rock, gazing at the camera, with a pink tree and a white cat in the background.* | *A white cat with green eyes and a blue collar sits on a moss-covered rock in a forest, gazing directly at the camera.* |

## 2 - Configuration

Below is the training configuration that we'll use for this tutorial.

Raw config file: [src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml).


```yaml title="sdxl_finetune_robocats_1x24gb.yaml"
--8<-- "src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml"
```

Full documentation of all of the configuration options is here: [Finetune SDXL Config](../../reference/config/pipelines/sdxl_finetune.md)

!!! note "`save_checkpoint_format`"
    Note the `save_checkpoint_format` setting, as it is unique to full finetune training. For this tutorial, we have set `save_checkpoint_format: trained_only_diffusers`. This means that only the UNet model will be saved at each checkpoint, and it will be saved in diffusers format. This setting conserves disk space by not redundantly saving the non-trained weights. Before these UNet checkpoints can be used, they must either be merged into a full model, or extracted into a LoRA. Instructions for this follow later in this tutorial. A full explanation of the `save_checkpoint_format` options can be found here:  [save_checkpoint_format][invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig.save_checkpoint_format].


## 3 - Start Training

Launch the training run.
```bash
# From inside the invoke-training/ source directory:
invoke-train -c src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml
```

Training takes ~45 mins on an NVIDIA RTX 4090.

## 4 - Monitor

In a new terminal, launch Tensorboard to monitor the training run:
```bash
tensorboard --logdir output/
```
Access Tensorboard at [localhost:6006](http://localhost:6006) in your browser.

Sample images will be logged to Tensorboard so that you can see how the model is evolving.

Once training is complete, select the model checkpoint that produces the best visual results.

## 5 - Prepare the trained model

Since we set `save_checkpoint_format: trained_only_diffusers`, our selected checkpoint only contains the UNet model weights. The checkpoint has the following directory structure:

```bash
output/robocats/sdxl_finetune/1715373799.3558652/checkpoints/checkpoint-epoch_00000500-step_00002000/
└── unet
    ├── config.json
    └── diffusion_pytorch_model.safetensors
```

Before we can use this trained model, we must do one of the following:

- Prepare a full diffusers checkpoint with the new UNet weights.
- Extract the difference between the trained UNet and the original UNet into a LoRA model.

### Prepare a full model

If we want to use our finetuned UNet model, we must first package it into a format supported by applications like InvokeAI.

In this section we will assume that we have a full SDXL base model in diffusers format. It should have a directory structure like the one shown below. We simply need to replace the `unet/` directory with the one from our selected training checkpoint:
```bash
stable-diffusion-xl-base-1.0
├── model_index.json
├── scheduler
│   └── scheduler_config.json
├── text_encoder
│   ├── config.json
│   └── model.fp16.safetensors
├── text_encoder_2
│   ├── config.json
│   └── model.fp16.safetensors
├── tokenizer
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── tokenizer_2
│   ├── merges.txt
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   └── vocab.json
├── unet # <-- Replace this directory with the trained checkpoint.
│   ├── config.json
│   └── diffusion_pytorch_model.fp16.safetensors
├── vae
│   ├── config.json
│   └── diffusion_pytorch_model.fp16.safetensors
└── vae_1_0
    └── diffusion_pytorch_model.fp16.safetensors
```

!!! note "diffusers variants (e.g. 'fp16')"
    In this example, notice that the `*.safetensors` files contain `.fp16.` in their filenames. Hugging Face refers to this identifier as a "variant". It is used to select between multiple model variants in their model hub.
    In this case, we should add the `.fp16.` variant tag to our finetuned UNet for consistency with the rest of the model. Since we set `save_dtype: float16` in our training config, the `fp16` tag accurately represents the precision of our UNet model file.

### Extract a LoRA model

An alternative to using the finetuned UNet model directly is to compare it against the original and extract the difference as a LoRA model. The resultant LoRA has a much smaller file size and can be applied to any base model. But, the LoRA model is a *lossy* representation of the difference, so some quality degradation is expected.

To extract a LoRA model, run the following command:
```bash
python src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py \
  --model-type SDXL \
  --model-orig path/to/stable-diffusion-xl-base-1.0 \
  --model-tuned output/robocats/sdxl_finetune/1715373799.3558652/checkpoints/checkpoint-epoch_00000500-step_00002000 \
  --save-to robocats_lora_step_2000.safetensors \
  --lora-rank 32
```

## 6 - Import into InvokeAI

If you haven't already, setup [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.

Import your finetuned diffusers model or your extracted LoRA from the 'Models' tab.

Congratulations, you can now use your new robocat model! 🎉

## 7 - Comparison: Finetune vs. LoRA Extraction

As noted earlier, the LoRA extraction process is lossy for a number of reasons.

Below, we compare images generated with the same seed and prompt for 3 different model configurations.

Prompt: *In robocat style, a robotic lion in the jungle.*

| SDXL Base 1.0 | w/ Finetuned UNet | w/ Extracted LoRA |
| - | - | - |
| ![Image generated with SDXL Base 1.0. Prompt: In robocat style, a robotic lion in the jungle.](../../images/robocats/lion_base.jpg) | ![Image generated with finetuned UNet. Prompt: In robocat style, a robotic lion in the jungle.](../../images/robocats/lion_finetuned.jpg) | ![Image generated with extracted LoRA. Prompt: In robocat style, a robotic lion in the jungle.](../../images/robocats/lion_extracted_lora.jpg)


================================================
FILE: docs/guides/stable_diffusion/textual_inversion_sdxl.md
================================================
# Textual Inversion - SDXL

This tutorial walks through a [Textual Inversion](https://arxiv.org/abs/2208.01618) training run with a [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) base model.

## 1 - Dataset

For this tutorial, we'll use a dataset consisting of 4 images of Bruce the Gnome:

| | |
| - | - |
| ![bruce_the_gnome dataset image 1.](../../images/bruce_the_gnome/001.jpg) | ![bruce_the_gnome dataset image 2.](../../images/bruce_the_gnome/002.jpg) |
| ![bruce_the_gnome dataset image 3.](../../images/bruce_the_gnome/003.jpg) | ![bruce_the_gnome dataset image 4.](../../images/bruce_the_gnome/004.jpg) |

This sample dataset is included in the invoke-training repo under [sample_data/bruce_the_gnome](https://github.com/invoke-ai/invoke-training/tree/main/sample_data/bruce_the_gnome).

Here are a few tips for preparing a Textual Inversion dataset:

- Aim for 4 to 50 images of your concept (object / style). The optimal number depends on many factors, and can be much higher than this for some use cases.
- Vary all of the image features that you *don't* want your TI embedding to contain (e.g. background, pose, lighting, etc.).

## 2 - Configuration

Below is the training configuration that we'll use for this tutorial.

Raw config file: [src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml).

Full config reference docs: [Textual Inversion SDXL Config](../../reference/config/pipelines/sdxl_textual_inversion.md)

```yaml title="sdxl_textual_inversion_gnome_1x24gb.yaml"
--8<-- "src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml"
```

## 3 - Start Training

[Install invoke-training](../../get-started/installation.md), if you haven't already.

Launch the Textual Inversion training pipeline:
```bash
# From inside the invoke-training/ source directory:
invoke-train -c src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml
```

Training takes ~40 mins on an NVIDIA RTX 4090.

## 4 - Monitor

In a new terminal, launch Tensorboard to monitor the training run:
```bash
tensorboard --logdir output/
```
Access Tensorboard at [localhost:6006](http://localhost:6006) in your browser.

Sample images will be logged to Tensorboard so that you can see how the Textual Inversion embedding is evolving.

Once training is complete, select the epoch that produces the best visual results.

For this tutorial, we'll choose epoch 500:
![Screenshot of the Tensorboard UI showing the validation images for epoch 500.](../../images/tensorboard_bruce_the_gnome_epoch_500.png)
*Screenshot of the Tensorboard UI showing the validation images for epoch 500.*

## 5 - Transfer to InvokeAI

If you haven't already, set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.

Copy the selected TI embedding into your `${INVOKEAI_ROOT}/autoimport/embedding/` directory. For example:
```bash
cp output/sdxl_ti_bruce_the_gnome/1702587511.2273068/checkpoint_epoch-00000500.safetensors ${INVOKEAI_ROOT}/autoimport/embedding/bruce_the_gnome.safetensors
```

Note that we renamed the file to `bruce_the_gnome.safetensors`. You can choose any file name, but this will become the token used to reference your embedding. So, in our case, we can refer to our new embedding by including `<bruce_the_gnome>` in our prompts.

Launch InvokeAI, and you can now use your new `bruce_the_gnome` TI embedding! 🎉

![Screenshot of the InvokeAI UI with an example of an image generated with the bruce_the_gnome TI embedding.](../../images/invokeai_bruce_the_gnome_ti.png)
*Example image generated with the prompt "`a photo of <bruce_the_gnome> at the park`".*


================================================
FILE: docs/index.md
================================================
# invoke-training

A library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion, etc.) that can be used in [InvokeAI](https://github.com/invoke-ai/InvokeAI).

## Documentation

The documentation is organized as follows:

- [Get Started](get-started/installation.md): Install `invoke-training` and run your first training pipeline.
- [Guides](guides/dataset_formats.md): Full tutorials for running popular training pipelines.
- [Config Reference](reference/config/index.md): Reference documentation for all supported training configuration options.
- [Contributing](contributing/development_environment.md): Information for `invoke-training` developers.


================================================
FILE: docs/reference/config/index.md
================================================
# Config Reference

This section contains reference documentation for the `invoke-training` configuration schema (i.e. documentation for all of the supported training options).

This documentation uses Python typing semantics to define the configuration schema. Typically, the configuration for a training run is specified in a YAML file and then parsed against this schema.


================================================
FILE: docs/reference/config/pipelines/sd_lora.md
================================================
# `SdLoraConfig`

<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->
::: invoke_training.pipelines.stable_diffusion.lora.config.SdLoraConfig
    options:
      members:
      - type

<!-- Note that we always hide "model_config", as it should not be set by the user. -->
::: invoke_training.pipelines.stable_diffusion.lora.config.SdLoraConfig
    options:
      filters:
      - "!^model_config"
      - "!^type"

================================================
FILE: docs/reference/config/pipelines/sd_textual_inversion.md
================================================
# `SdTextualInversionConfig`

<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->
::: invoke_training.pipelines.stable_diffusion.textual_inversion.config.SdTextualInversionConfig
    options:
      members:
      - type

<!-- Note that we always hide "model_config", as it should not be set by the user. -->
::: invoke_training.pipelines.stable_diffusion.textual_inversion.config.SdTextualInversionConfig
    options:
      filters:
      - "!^model_config"
      - "!^type"


================================================
FILE: docs/reference/config/pipelines/sdxl_finetune.md
================================================
# `SdxlFinetuneConfig`

<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->
::: invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig
    options:
      members:
      - type

<!-- Note that we always hide "model_config", as it should not be set by the user. -->
::: invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig
    options:
      filters:
      - "!^model_config"
      - "!^type"


================================================
FILE: docs/reference/config/pipelines/sdxl_lora.md
================================================
# `SdxlLoraConfig`

<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->
::: invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig
    options:
      members:
      - type

<!-- Note that we always hide "model_config", as it should not be set by the user. -->
::: invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig
    options:
      filters:
      - "!^model_config"
      - "!^type"

================================================
FILE: docs/reference/config/pipelines/sdxl_lora_and_textual_inversion.md
================================================
# `SdxlLoraAndTextualInversionConfig`

<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->
::: invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config.SdxlLoraAndTextualInversionConfig
    options:
      members:
      - type

<!-- Note that we always hide "model_config", as it should not be set by the user. -->
::: invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config.SdxlLoraAndTextualInversionConfig
    options:
      filters:
      - "!^model_config"
      - "!^type"

================================================
FILE: docs/reference/config/pipelines/sdxl_textual_inversion.md
================================================
# `SdxlTextualInversionConfig`

Below is a sample yaml config file for Textual Inversion SDXL training ([raw file](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml)). All of the configuration fields are explained in detail on this page.

```yaml title="sdxl_textual_inversion_gnome_1x24gb.yaml"
--8<-- "src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml"
```

<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->
::: invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config.SdxlTextualInversionConfig
    options:
      members:
      - type

<!-- Note that we always hide "model_config", as it should not be set by the user. -->
::: invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config.SdxlTextualInversionConfig
    options:
      filters:
      - "!^model_config"
      - "!^type"


================================================
FILE: docs/reference/config/shared/data/data_loader_config.md
================================================
::: invoke_training.config.data.data_loader_config
    options:
      filters:
      - "!^model_config"


================================================
FILE: docs/reference/config/shared/data/dataset_config.md
================================================
::: invoke_training.config.data.dataset_config
    options:
      filters:
      - "!^model_config"

================================================
FILE: docs/reference/config/shared/optimizer_config.md
================================================
::: invoke_training.config.optimizer.optimizer_config
    options:
      filters:
      - "!^model_config"


================================================
FILE: docs/templates/python/material/labels.html
================================================
<!--
    This file is intentionally empty. It overrides the default contents of
    https://github.com/mkdocstrings/python/blob/master/src/mkdocstrings_handlers/python/templates/material/labels.html
    to hide labels (class-attribute, instance-attribute, etc.)
-->


================================================
FILE: mkdocs.yml
================================================
site_name: invoke-training
site_url: https://invoke-ai.github.io/invoke-training/

repo_name: invoke-ai/invoke-training
repo_url: https://github.com/invoke-ai/invoke-training

theme:
  name: material
  features:
    - navigation.tabs
    - navigation.indexes
    - navigation.sections
    - content.code.copy

markdown_extensions:
  - admonition
  - sane_lists
  - pymdownx.highlight:
      anchor_linenums: true
      line_spans: __span
      pygments_lang_class: true
  - pymdownx.inlinehilite
  - pymdownx.snippets
  - pymdownx.superfences

nav:
  - Welcome: index.md
  - Get Started:
      - get-started/installation.md
      - get-started/quick-start.md
  - Guides:
      - Dataset Formats: guides/dataset_formats.md
      - Model Merging: guides/model_merge.md
      - Stable Diffusion Training:
          - guides/stable_diffusion/robocats_finetune_sdxl.md
          - guides/stable_diffusion/gnome_lora_masks_sdxl.md
          - guides/stable_diffusion/textual_inversion_sdxl.md
          - guides/stable_diffusion/dpo_lora_sd.md
  - YAML Config Reference:
      - reference/config/index.md
      - pipelines:
          - SD LoRA Config: reference/config/pipelines/sd_lora.md
          - SD Textual Inversion Config: reference/config/pipelines/sd_textual_inversion.md
          - SDXL LoRA Config: reference/config/pipelines/sdxl_lora.md
          - SDXL Textual Inversion Config: reference/config/pipelines/sdxl_textual_inversion.md
          - SDXL LoRA and Textual Inversion Config: reference/config/pipelines/sdxl_lora_and_textual_inversion.md
          - SDXL Finetune Config: reference/config/pipelines/sdxl_finetune.md
      - shared:
          - data_loader_config: reference/config/shared/data/data_loader_config.md
          - dataset_config: reference/config/shared/data/dataset_config.md
          - optimizer_config: reference/config/shared/optimizer_config.md
  - Contributing:
      - contributing/development_environment.md
      - contributing/directory_structure.md
      - contributing/tests.md
      - contributing/documentation.md

plugins:
  - search
  - mkdocstrings:
      default_handler: python
      custom_templates: docs/templates
      handlers:
        python:
          options:
            show_root_heading: false
            show_root_toc_entry: false
            show_bases: false
            show_source: false
            show_if_no_docstring: true
            inherited_members: true
            annotations_path: brief
            separate_signature: true
            show_signature_annotations: true
            members_order: source


================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=65.5", "pip>=22.3"]
build-backend = "setuptools.build_meta"

[project]
name = "invoke-training"
version = "0.0.1"
authors = [{ name = "The Invoke AI Team", email = "ryan@invoke.ai" }]
description = "A library for Stable Diffusion model training."
readme = "README.md"
requires-python = ">=3.10"
license = { text = "Apache-2.0" }
classifiers = [
    "Programming Language :: Python :: 3",
    "Operating System :: OS Independent",
]
dependencies = [
    "accelerate",
    "datasets~=2.14.3",
    "diffusers[torch]",
    "einops",
    "fastapi",
    "gradio",
    "invokeai>=5.10.0a1",
    "numpy<2.0.0",
    "omegaconf",
    "peft~=0.11.1",
    "pillow",
    "prodigyopt",
    "pydantic",
    "pyyaml",
    "safetensors",
    "tensorboard",
    "torch",
    "torchvision",
    "tqdm",
    "transformers",
    "uvicorn[standard]",
]

[project.optional-dependencies]
"xformers" = ["xformers>=0.0.28.post1; sys_platform!='darwin'"]
"bitsandbytes" = ["bitsandbytes>=0.43.1; sys_platform!='darwin'"]

"test" = [
    "mkdocs",
    "mkdocs-material",
    "mkdocstrings[python]",
    "pre-commit~=3.3.3",
    "pytest~=7.4.0",
    "ruff~=0.11.2",
    "ruff-lsp",
]

[project.scripts]
"invoke-train" = "invoke_training.scripts.invoke_train:main"
"invoke-train-ui" = "invoke_training.scripts.invoke_train_ui:main"
"invoke-generate-images" = "invoke_training.scripts.invoke_generate_images:main"
"invoke-visualize-data-loading" = "invoke_training.scripts.invoke_visualize_data_loading:main"

[project.urls]
"Homepage" = "https://github.com/invoke-ai/invoke-training"
"Discord" = "https://discord.gg/ZmtBAhwWhy"

[tool.setuptools.package-data]
"invoke_training.assets" = ["*.png"]
"invoke_training.sample_configs" = ["**/*.yaml"]
"invoke_training.ui" = ["*.html"]

[tool.ruff]
src = ["src"]
lint.select = ["E", "F", "W", "C9", "N8", "I"]
target-version = "py310"
line-length = 120

[tool.pytest.ini_options]
addopts = "--strict-markers"
markers = [
    "cuda: marks tests that require a CUDA GPU",
    "loads_model: marks tests that require a model (or data) from the HF hub",
]


================================================
FILE: sample_data/bruce_the_gnome/data.jsonl
================================================
{"image": "001.png", "text": "A stuffed gnome sits on a wooden floor, facing right with a gray couch in the background."}
{"image": "002.png", "text": "A stuffed gnome stands on a black tiled floor, with a silver refrigerator and white wall in the background."}
{"image": "004.png", "text": "A stuffed gnome sits on a white marble floor, photorealistic."}
{"image": "003.png", "text": "A stuffed gnome sits on a gray tiled floor, facing the camera."}


================================================
FILE: src/invoke_training/__init__.py
================================================


================================================
FILE: src/invoke_training/_shared/__init__.py
================================================


================================================
FILE: src/invoke_training/_shared/accelerator/__init__.py
================================================


================================================
FILE: src/invoke_training/_shared/accelerator/accelerator_utils.py
================================================
import logging
import os
from typing import Literal

import datasets
import diffusers
import torch
import transformers
from accelerate import Accelerator
from accelerate.logging import MultiProcessAdapter, get_logger
from accelerate.utils import ProjectConfiguration


def initialize_accelerator(
    out_dir: str, gradient_accumulation_steps: int, mixed_precision: str, log_with: str
) -> Accelerator:
    """Build a Hugging Face `accelerate.Accelerator` whose project directory is `out_dir`.

    The accelerator's logging directory is set to `{out_dir}/logs`.

    Args:
        out_dir (str): The output directory where results will be written.
        gradient_accumulation_steps (int): Forwarded to `accelerate.Accelerator(...)`.
        mixed_precision (str): Forwarded to `accelerate.Accelerator(...)`.
        log_with (str): Forwarded to `accelerate.Accelerator(...)`.

    Returns:
        Accelerator: The configured accelerator.
    """
    logs_dir = os.path.join(out_dir, "logs")
    project_config = ProjectConfiguration(project_dir=out_dir, logging_dir=logs_dir)

    accelerator = Accelerator(
        log_with=log_with,
        mixed_precision=mixed_precision,
        gradient_accumulation_steps=gradient_accumulation_steps,
        project_config=project_config,
    )
    return accelerator


def initialize_logging(logger_name: str, accelerator: Accelerator) -> MultiProcessAdapter:
    """Configure logging and return an accelerate logger with multi-process logging support.

    Configures the root `logging` module at INFO level with a timestamped format. Note that
    `logging.basicConfig(...)` has no effect if the root logger already has handlers.

    Hugging Face library verbosity depends on the process:
    - On the local main process: datasets and transformers log at warning level, diffusers logs at info level.
    - On all other processes: datasets, transformers, and diffusers only log errors.

    Args:
        logger_name (str): Name of the logger to create; forwarded to `accelerate.logging.get_logger(...)`.
        accelerator (Accelerator): The Accelerator whose `is_local_main_process` flag selects the verbosity above.

    Returns:
        MultiProcessAdapter: The accelerate logger named `logger_name`.
    """
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        # Only log errors from non-main processes to avoid duplicated output.
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    return get_logger(logger_name)


def get_mixed_precision_dtype(accelerator: Accelerator):
    """Translate the accelerator's `mixed_precision` setting into a `torch.dtype`.

    Mapping: None or "no" -> torch.float32, "fp16" -> torch.float16, "bf16" -> torch.bfloat16.

    Args:
        accelerator (Accelerator): The Hugging Face Accelerator.

    Raises:
        NotImplementedError: If `accelerator.mixed_precision` is any other value.

    Returns:
        torch.dtype: The weight dtype implied by the mixed-precision mode.
    """
    mode = accelerator.mixed_precision
    if mode is None or mode == "no":
        return torch.float32
    if mode == "fp16":
        return torch.float16
    if mode == "bf16":
        return torch.bfloat16
    raise NotImplementedError(f"mixed_precision mode '{mode}' is not yet supported.")


def get_dtype_from_str(dtype_str: Literal["float16", "bfloat16", "float32"]) -> torch.dtype:
    """Return the `torch.dtype` named by `dtype_str`.

    Args:
        dtype_str: One of "float16", "bfloat16", or "float32".

    Raises:
        ValueError: If `dtype_str` is not one of the supported names.
    """
    supported = (
        ("float16", torch.float16),
        ("bfloat16", torch.bfloat16),
        ("float32", torch.float32),
    )
    for name, dtype in supported:
        if dtype_str == name:
            return dtype
    raise ValueError(f"Unsupported dtype: {dtype_str}")


================================================
FILE: src/invoke_training/_shared/checkpoints/__init__.py
================================================


================================================
FILE: src/invoke_training/_shared/checkpoints/checkpoint_tracker.py
================================================
import os
import shutil
import typing


class CheckpointTracker:
    """A utility class for managing checkpoint paths.

    Manages checkpoint paths of the following forms:
    - Checkpoint directories: `{base_dir}/{prefix}-epoch_{num_epochs}-step_{num_steps}`
    - Checkpoint files: `{base_dir}/{prefix}-epoch_{num_epochs}-step_{num_steps}{extension}`

    Leading/trailing whitespace in `prefix` is stripped, both when generating paths and when pruning.
    """

    def __init__(
        self,
        base_dir: str,
        prefix: str,
        extension: typing.Optional[str] = None,
        max_checkpoints: typing.Optional[int] = None,
        index_padding: int = 8,
    ):
        """Initialize a CheckpointTracker.

        Args:
            base_dir (str): The base checkpoint directory.
            prefix (str): A prefix applied to every checkpoint.
            extension (str, optional): If set, this is the file extension that will be applied to all checkpoints
                (usually one of ".pt", ".ckpt", or ".safetensors"). If None, then it will be assumed that we are
                managing checkpoint directories rather than files.
            max_checkpoints (typing.Optional[int], optional): The maximum number of checkpoints that should exist in
                base_dir. If None, `prune(...)` never deletes anything.
            index_padding (int, optional): The length of the zero-padded epoch/step counts in the generated checkpoint
                names. E.g. index_padding=8 would produce checkpoint paths like
                "base_dir/prefix-epoch_00000001-step_00000001.ckpt".

        Raises:
            ValueError: If extension is provided, but it does not start with a '.'.
        """
        if extension is not None and not extension.startswith("."):
            raise ValueError(f"extension='{extension}' must start with a '.'.")

        self._base_dir = base_dir
        self._prefix = prefix
        self._extension = extension
        self._max_checkpoints = max_checkpoints
        self._index_padding = index_padding

    def _parse_step(self, name: str) -> typing.Optional[int]:
        """Return the step count encoded in checkpoint entry `name`, or None if `name` is not a managed checkpoint."""
        # Match the exact naming scheme produced by `get_path(...)` (which strips the prefix), so that unrelated
        # entries that merely share the prefix (e.g. "ckpt_notes.txt") are ignored rather than crashing the sort.
        if not name.startswith(f"{self._prefix.strip()}-epoch_"):
            return None
        _, sep, step_part = name.rpartition("-step_")
        if not sep:
            return None
        # Drop any file extension following the step count (e.g. "00000002.safetensors" -> "00000002").
        step_str = step_part.split(".", 1)[0]
        try:
            return int(step_str)
        except ValueError:
            return None

    def prune(self, buffer_num: int = 1) -> int:
        """Delete checkpoint files and directories so that there are at most `max_checkpoints - buffer_num` checkpoints
        remaining. The checkpoints with the lowest step counts will be deleted.

        Entries in `base_dir` that do not follow this tracker's naming scheme are left untouched.

        Args:
            buffer_num (int, optional): The number below `max_checkpoints` to 'free-up'.

        Returns:
            int: The number of checkpoints deleted.
        """
        if self._max_checkpoints is None:
            return 0
        if not os.path.isdir(self._base_dir):
            # Nothing has been saved yet, so there is nothing to prune.
            return 0

        checkpoints: list[tuple[int, str]] = []
        for name in os.listdir(self._base_dir):
            step = self._parse_step(name)
            if step is not None:
                checkpoints.append((step, name))
        # Sort on the step only; the sort is stable, so ties keep their os.listdir(...) order.
        checkpoints.sort(key=lambda item: item[0])

        num_to_remove = len(checkpoints) - (self._max_checkpoints - buffer_num)
        checkpoints_to_remove = checkpoints[: max(0, num_to_remove)]

        for _, name in checkpoints_to_remove:
            checkpoint_path = os.path.join(self._base_dir, name)
            if os.path.isfile(checkpoint_path):
                # Delete checkpoint file.
                os.remove(checkpoint_path)
            else:
                # Delete checkpoint directory.
                shutil.rmtree(checkpoint_path)

        # Report what was actually deleted (num_to_remove can exceed the number of existing checkpoints when
        # max_checkpoints - buffer_num is negative).
        return len(checkpoints_to_remove)

    def get_path(self, epoch: int, step: int) -> str:
        """Get the checkpoint path for the given epoch and step counts.

        Args:
            epoch (int): The number of completed epochs.
            step (int): The number of completed training steps.

        Returns:
            str: The checkpoint path.
        """
        suffix = self._extension or ""
        return os.path.join(
            self._base_dir,
            f"{self._prefix.strip()}-epoch_{epoch:0>{self._index_padding}}-step_{step:0>{self._index_padding}}{suffix}",
        )


================================================
FILE: src/invoke_training/_shared/checkpoints/lora_checkpoint_utils.py
================================================
from pathlib import Path

import peft
import torch


def save_multi_model_peft_checkpoint(checkpoint_dir: Path | str, models: dict[str, peft.PeftModel]):
    """Write each PeftModel in `models` into its own subdirectory of `checkpoint_dir`.

    Each model's subdirectory is named after its key in `models`. The resulting checkpoint can be read back with
    `load_multi_model_peft_checkpoint(...)`.
    """
    root = Path(checkpoint_dir)
    for subdir_name, model in models.items():
        assert isinstance(model, peft.PeftModel)

        # HACK(ryand): PeftModel.save_pretrained(...) expects the config to have a "_name_or_path" entry. For now, we
        # set this to None here. This should be fixed upstream in PEFT.
        model_config = getattr(model, "config", None)
        if isinstance(model_config, dict) and "_name_or_path" not in model_config:
            model_config["_name_or_path"] = None

        model.save_pretrained(str(root / subdir_name))


def load_multi_model_peft_checkpoint(
    checkpoint_dir: Path | str,
    models: dict[str, torch.nn.Module],
    is_trainable: bool = False,
    raise_if_subdir_missing: bool = True,
) -> dict[str, torch.nn.Module]:
    """Load a multi-model PEFT checkpoint that was saved with `save_multi_model_peft_checkpoint(...)`."""
    checkpoint_dir = Path(checkpoint_dir)
    assert checkpoint_dir.exists()

    out_models = {}
    for model_key, model in models.items():
        dir_path: Path = checkpoint_dir / model_key
        if dir_path.exists():
            out_models[model_key] = peft.PeftModel.from_pretrained(model, dir_path, is_trainable=is_trainable)
        else:
            if raise_if_subdir_missing:
                raise ValueError(f"'{dir_path}' does not exist.")
            else:
                # Pass through the model unchanged.
                out_models[model_key] = model

    return out_models


# This implementation is based on
# https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/lora_dreambooth/convert_peft_sd_lora_to_kohya_ss.py#L20
def _convert_peft_state_dict_to_kohya_state_dict(
    lora_config: peft.LoraConfig,
    peft_state_dict: dict[str, torch.Tensor],
    prefix: str,
    dtype: torch.dtype,
) -> dict[str, torch.Tensor]:
    """Rename PEFT LoRA state_dict keys to the kohya_ss convention and cast every tensor to `dtype`.

    Per key:
    - "base_model.model" is replaced by `prefix`, "lora_A" by "lora_down", and "lora_B" by "lora_up".
    - All but the last two "." characters become "_".
    - For each "lora_down" key, an extra "{module}.alpha" entry holding `lora_config.lora_alpha` (cast to `dtype`)
      is added.

    NOTE(review): if a key has fewer than two ".", `str.replace` receives a negative count and replaces *every* ".".
    """
    converted: dict[str, torch.Tensor] = {}
    for key, tensor in peft_state_dict.items():
        new_key = key.replace("base_model.model", prefix).replace("lora_A", "lora_down").replace("lora_B", "lora_up")
        # Flatten the module path to underscores while keeping the trailing ".lora_down.weight"-style suffix.
        new_key = new_key.replace(".", "_", new_key.count(".") - 2)
        converted[new_key] = tensor.to(dtype)

        if "lora_down" in new_key:
            module_name = new_key.split(".")[0]
            converted[f"{module_name}.alpha"] = torch.tensor(lora_config.lora_alpha).to(dtype)

    return converted


def _convert_peft_models_to_kohya_state_dict(
    kohya_prefixes: list[str], models: list[peft.PeftModel]
) -> dict[str, torch.Tensor]:
    """Merge the "default" LoRA adapter of every PeftModel into one kohya_ss-format float32 state_dict.

    `kohya_prefixes[i]` is the key prefix used for `models[i]`. The two lists must be the same length (enforced by
    `zip(..., strict=True)`).
    """
    adapter_name = "default"
    merged: dict[str, torch.Tensor] = {}

    for kohya_prefix, peft_model in zip(kohya_prefixes, models, strict=True):
        config = peft_model.peft_config[adapter_name]
        assert isinstance(config, peft.LoraConfig)

        adapter_state = peft.get_peft_model_state_dict(peft_model, adapter_name=adapter_name)
        merged.update(
            _convert_peft_state_dict_to_kohya_state_dict(
                lora_config=config,
                peft_state_dict=adapter_state,
                prefix=kohya_prefix,
                dtype=torch.float32,
            )
        )

    return merged


================================================
FILE: src/invoke_training/_shared/checkpoints/serialization.py
================================================
import typing
from pathlib import Path

import safetensors.torch
import torch


def save_state_dict(state_dict: typing.Dict[str, torch.Tensor], out_file: typing.Union[Path, str]):
    """Write `state_dict` to `out_file`, picking the format from the file extension.

    Supported extensions:
    - ".ckpt" -> torch
    - ".pt" -> torch
    - ".safetensors" -> safetensors

    Args:
        state_dict (typing.Dict[str, torch.Tensor]): The state_dict to save.
        out_file (Path | str): The output file to save to.

    Raises:
        ValueError: If the `out_file` has an unsupported file extension.
    """
    target = Path(out_file)
    ext = target.suffix
    if ext == ".safetensors":
        safetensors.torch.save_file(state_dict, target)
    elif ext in (".ckpt", ".pt"):
        torch.save(state_dict, target)
    else:
        raise ValueError(f"Unsupported file extension: '{ext}'.")


def load_state_dict(in_file: typing.Union[Path, str]) -> typing.Dict[str, torch.Tensor]:
    """Read a state_dict from `in_file`, picking the format from the file extension.

    Supported extensions:
    - ".ckpt" -> torch
    - ".pt" -> torch
    - ".safetensors" -> safetensors

    Args:
        in_file (Path | str): The input file to load from.

    Raises:
        ValueError: If the `in_file` has an unsupported file extension.

    Returns:
        typing.Dict[str, torch.Tensor]: The loaded state_dict.
    """
    source = Path(in_file)
    ext = source.suffix
    if ext in (".ckpt", ".pt"):
        # NOTE(review): torch.load unpickles the file (unless the installed torch defaults to weights_only=True), which
        # can execute arbitrary code. Only load trusted files, or consider passing weights_only=True.
        return torch.load(source)
    if ext == ".safetensors":
        return safetensors.torch.load_file(source)
    raise ValueError(f"Unsupported file extension: '{ext}'.")


================================================
FILE: src/invoke_training/_shared/data/ARCHITECTURE.md
================================================
# Dataset Architecture
Dataset handling is split into 3 layers of abstraction: Datasets, Transforms, and DataLoaders. Each is explained in more detail below.

## Datasets

Datasets implement the [torch.utils.data.Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files) interface.

Most dataset classes act as an abstraction over a specific dataset format.
 
## Transforms

Transforms are functions applied to data loaded by Datasets. For example, the `SDImageTransform` implements image augmentations for Stable Diffusion training.

Transforms are kept separate from the underlying datasets for several reasons:
- It is easier to write tests for isolated transforms.
- Modular transforms can often be re-used for multiple base datasets.
- Modular transforms make it easy to customize datasets for different situations. For example, you may want to wrap a dataset with one set of transforms initially to populate a cache, and then with a different set of transforms to read from the cache.

Transforms are applied to a dataset via the `TransformDataset` class.

## DataLoaders

The dataset classes (with composed transforms) are wrapped in a `torch.utils.data.DataLoader` that handles batch collation, multi-processing, etc.


================================================
FILE: src/invoke_training/_shared/data/__init__.py
================================================


================================================
FILE: src/invoke_training/_shared/data/data_loaders/__init__.py
================================================


================================================
FILE: src/invoke_training/_shared/data/data_loaders/dreambooth_sd_dataloader.py
================================================
import typing

from torch.utils.data import ConcatDataset, DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler

from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
    build_aspect_ratio_bucket_manager,
    sd_image_caption_collate_fn,
)
from invoke_training._shared.data.datasets.image_dir_dataset import ImageDirDataset
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler
from invoke_training._shared.data.samplers.batch_offset_sampler import BatchOffsetSampler
from invoke_training._shared.data.samplers.concat_sampler import ConcatSampler
from invoke_training._shared.data.samplers.interleaved_sampler import InterleavedSampler
from invoke_training._shared.data.samplers.offset_sampler import OffsetSampler
from invoke_training._shared.data.transforms.constant_field_transform import ConstantFieldTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig


def build_dreambooth_sd_dataloader(
    config: DreamboothSDDataLoaderConfig,
    batch_size: int,
    text_encoder_output_cache_dir: typing.Optional[str] = None,
    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,
    vae_output_cache_dir: typing.Optional[str] = None,
    shuffle: bool = True,
    sequential_batching: bool = False,
) -> DataLoader:
    """Construct a DataLoader for a DreamBooth dataset for Stable Diffusion.

    Instance images (and, if `config.class_dataset` is set, class images) are loaded from image directories and each
    example is tagged with a constant caption and loss weight. The datasets are concatenated (instance first, then
    class), and one sampler per source dataset is built and then merged either by interleaving or by sequential
    concatenation.

    Args:
        config (DreamboothSDDataLoaderConfig): The DreamBooth data loader config.
        batch_size (int): The batch size. When `config.aspect_ratio_buckets` is set, it is used by the
            AspectRatioBucketBatchSampler(s); otherwise it is passed to the DataLoader directly.
        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should be
            loaded from.
        text_encoder_cache_field_to_output_field (dict[str, str], optional): Mapping from text encoder cache field
            names to output example field names. Must be set when `text_encoder_output_cache_dir` is set.
        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If
            set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM.
        shuffle (bool, optional): Whether to shuffle the dataset order.
        sequential_batching (bool, optional): If True, the internal dataset will be processed sequentially rather than
            interleaving class and instance examples. This is intended to be used when processing the entire dataset for
            caching purposes. Defaults to False.

    Returns:
        DataLoader: Driven by a per-example `sampler` + `batch_size` when aspect ratio buckets are disabled, or by a
            `batch_sampler` when they are enabled.
    """
    # Prepare instance dataset.
    base_instance_dataset = ImageDirDataset(
        config.instance_dataset.dataset_dir,
        id_prefix="instance_",
        keep_in_memory=config.instance_dataset.keep_in_memory,
    )
    instance_dataset = TransformDataset(
        base_instance_dataset,
        [
            ConstantFieldTransform("caption", config.instance_caption),
            ConstantFieldTransform("loss_weight", 1.0),
        ],
    )
    datasets = [instance_dataset]

    # Prepare class dataset.
    base_class_dataset = None
    class_dataset = None
    if config.class_dataset is not None:
        base_class_dataset = ImageDirDataset(
            config.class_dataset.dataset_dir, id_prefix="class_", keep_in_memory=config.class_dataset.keep_in_memory
        )
        class_dataset = TransformDataset(
            base_class_dataset,
            [
                ConstantFieldTransform("caption", config.class_caption),
                ConstantFieldTransform("loss_weight", config.class_data_loss_weight),
            ],
        )
        datasets.append(class_dataset)

    # Merge instance dataset and class dataset.
    # Instance examples come first in the ConcatDataset, so class examples start at index len(base_instance_dataset).
    merged_dataset = ConcatDataset(datasets)

    # Initialize either the fixed target resolution or aspect ratio buckets.
    target_resolution = None
    aspect_ratio_bucket_manager = None
    instance_sampler = None
    class_sampler = None
    if config.aspect_ratio_buckets is None:
        target_resolution = config.resolution
        # TODO(ryand): Provide a seeded generator.
        instance_sampler = RandomSampler(instance_dataset) if shuffle else SequentialSampler(instance_dataset)
        if base_class_dataset is not None:
            class_sampler = RandomSampler(class_dataset) if shuffle else SequentialSampler(class_dataset)
            # Shift class indices past the instance examples so they address the class part of merged_dataset.
            class_sampler = OffsetSampler(class_sampler, offset=len(base_instance_dataset))
    else:
        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)
        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.
        instance_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
            bucket_manager=aspect_ratio_bucket_manager,
            image_sizes=base_instance_dataset.get_image_dimensions(),
            batch_size=batch_size,
            shuffle=shuffle,
            seed=0,
        )
        if base_class_dataset is not None:
            class_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
                bucket_manager=aspect_ratio_bucket_manager,
                image_sizes=base_class_dataset.get_image_dimensions(),
                batch_size=batch_size,
                shuffle=shuffle,
                seed=0,
            )
            # Same index shift as above, applied to every index in each yielded batch.
            class_sampler = BatchOffsetSampler(class_sampler, offset=len(base_instance_dataset))

    # Add transforms to the merged dataset.
    all_transforms = []
    if vae_output_cache_dir is None:
        all_transforms.append(
            SDImageTransform(
                image_field_names=["image"],
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )
    else:
        vae_cache = TensorDiskCache(vae_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache,
                cache_key_field="id",
                cache_field_to_output_field={
                    "vae_output": "vae_output",
                    "original_size_hw": "original_size_hw",
                    "crop_top_left_yx": "crop_top_left_yx",
                },
            )
        )
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))

    if text_encoder_output_cache_dir is not None:
        assert text_encoder_cache_field_to_output_field is not None
        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=text_encoder_cache,
                cache_key_field="id",
                cache_field_to_output_field=text_encoder_cache_field_to_output_field,
            )
        )

    merged_dataset = TransformDataset(merged_dataset, all_transforms)

    # Choose between sequential vs. interleaved merging of the instance and class samplers.
    # Sequential sampling is typically used to populate a cache, because it guarantees that all examples will be
    # included in an epoch.
    samplers = [instance_sampler]
    if class_sampler is not None:
        samplers.append(class_sampler)
    if sequential_batching:
        sampler = ConcatSampler(samplers)
    else:
        sampler = InterleavedSampler(samplers)

    # DataLoader treats `batch_sampler` as mutually exclusive with `batch_size`, hence the two call forms.
    if config.aspect_ratio_buckets is None:
        return DataLoader(
            merged_dataset,
            sampler=sampler,
            collate_fn=sd_image_caption_collate_fn,
            batch_size=batch_size,
            num_workers=config.dataloader_num_workers,
        )
    else:
        # If config.aspect_ratio_buckets is not None, then we are using a batch sampler.
        return DataLoader(
            merged_dataset,
            batch_sampler=sampler,
            collate_fn=sd_image_caption_collate_fn,
            num_workers=config.dataloader_num_workers,
        )


================================================
FILE: src/invoke_training/_shared/data/data_loaders/image_caption_flux_dataloader.py
================================================
import typing

from torch.utils.data import DataLoader

from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
    build_aspect_ratio_bucket_manager,
)
from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
    sd_image_caption_collate_fn as flux_image_caption_collate_fn,
)
from invoke_training._shared.data.datasets.build_dataset import (
    build_hf_hub_image_caption_dataset,
    build_image_caption_dir_dataset,
    build_image_caption_jsonl_dataset,
)
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import (
    AspectRatioBucketBatchSampler,
)
from invoke_training._shared.data.transforms.caption_prefix_transform import CaptionPrefixTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.flux_image_transform import FluxImageTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training.config.data.data_loader_config import ImageCaptionFluxDataLoaderConfig
from invoke_training.config.data.dataset_config import (
    HFHubImageCaptionDatasetConfig,
    ImageCaptionDirDatasetConfig,
    ImageCaptionJsonlDatasetConfig,
)


def build_image_caption_flux_dataloader(  # noqa: C901
    config: ImageCaptionFluxDataLoaderConfig,
    batch_size: int,
    use_masks: bool = False,
    text_encoder_output_cache_dir: typing.Optional[str] = None,
    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,
    vae_output_cache_dir: typing.Optional[str] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Construct a DataLoader for an image-caption dataset for Flux.1-dev.

    Args:
        config (ImageCaptionFluxDataLoaderConfig): The data loader config. `config.dataset` selects the dataset format
            (HF Hub, JSONL, or image-caption directory).
        batch_size (int): The batch size. When `config.aspect_ratio_buckets` is set, it is used by the
            AspectRatioBucketBatchSampler; otherwise it is passed to the DataLoader directly.
        use_masks (bool, optional): If True, the "mask" field is kept and is transformed (or loaded from the VAE cache)
            alongside the image. If False, the "mask" field is dropped.
        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should be
            loaded from.
        text_encoder_cache_field_to_output_field (dict[str, str], optional): Mapping from text encoder cache field
            names to output example field names. Must be set when `text_encoder_output_cache_dir` is set.
        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If
            set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM.
        shuffle (bool, optional): Whether to shuffle the dataset order.

    Returns:
        DataLoader
    """
    if isinstance(config.dataset, HFHubImageCaptionDatasetConfig):
        base_dataset = build_hf_hub_image_caption_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionJsonlDatasetConfig):
        base_dataset = build_image_caption_jsonl_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionDirDatasetConfig):
        base_dataset = build_image_caption_dir_dataset(config.dataset)
    else:
        raise ValueError(f"Unexpected dataset config type: '{type(config.dataset)}'.")

    # Initialize either the fixed target resolution or aspect ratio buckets. Exactly one of `target_resolution` and
    # `aspect_ratio_bucket_manager` is set, matching the SD image-caption and DreamBooth data loaders.
    if config.aspect_ratio_buckets is None:
        target_resolution = config.resolution
        aspect_ratio_bucket_manager = None
        batch_sampler = None
    else:
        target_resolution = None
        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)
        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.
        batch_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
            bucket_manager=aspect_ratio_bucket_manager,
            image_sizes=base_dataset.get_image_dimensions(),
            batch_size=batch_size,
            shuffle=shuffle,
            seed=0,
        )

    all_transforms = []

    if config.caption_prefix is not None:
        all_transforms.append(CaptionPrefixTransform(caption_field_name="caption", prefix=config.caption_prefix + " "))

    if vae_output_cache_dir is None:
        image_field_names = ["image"]
        if use_masks:
            image_field_names.append("mask")
        else:
            all_transforms.append(DropFieldTransform("mask"))

        all_transforms.append(
            FluxImageTransform(
                image_field_names=image_field_names,
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )
    else:
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))
        all_transforms.append(DropFieldTransform("mask"))

        vae_cache = TensorDiskCache(vae_output_cache_dir)

        cache_field_to_output_field = {
            "vae_output": "vae_output",
            "original_size_hw": "original_size_hw",
            "crop_top_left_yx": "crop_top_left_yx",
        }
        if use_masks:
            # The raw mask was dropped above; the cached (pre-transformed) mask is loaded instead.
            cache_field_to_output_field["mask"] = "mask"
        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache,
                cache_key_field="id",
                cache_field_to_output_field=cache_field_to_output_field,
            )
        )

    if text_encoder_output_cache_dir is not None:
        assert text_encoder_cache_field_to_output_field is not None
        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=text_encoder_cache,
                cache_key_field="id",
                cache_field_to_output_field=text_encoder_cache_field_to_output_field,
            )
        )
    dataset = TransformDataset(base_dataset, all_transforms)

    # DataLoader treats `batch_sampler` as mutually exclusive with `batch_size` and `shuffle`.
    if batch_sampler is None:
        return DataLoader(
            dataset,
            shuffle=shuffle,
            collate_fn=flux_image_caption_collate_fn,
            batch_size=batch_size,
            num_workers=config.dataloader_num_workers,
        )
    else:
        return DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            collate_fn=flux_image_caption_collate_fn,
            num_workers=config.dataloader_num_workers,
        )


================================================
FILE: src/invoke_training/_shared/data/data_loaders/image_caption_sd_dataloader.py
================================================
import typing

import torch
from torch.utils.data import DataLoader

from invoke_training._shared.data.datasets.build_dataset import (
    build_hf_hub_image_caption_dataset,
    build_image_caption_dir_dataset,
    build_image_caption_jsonl_dataset,
)
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler
from invoke_training._shared.data.transforms.caption_prefix_transform import CaptionPrefixTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager
from invoke_training.config.data.data_loader_config import AspectRatioBucketConfig, ImageCaptionSDDataLoaderConfig
from invoke_training.config.data.dataset_config import (
    HFHubImageCaptionDatasetConfig,
    ImageCaptionDirDatasetConfig,
    ImageCaptionJsonlDatasetConfig,
)


def sd_image_caption_collate_fn(examples):
    """A batch collation function for the image-caption data loaders.

    `id` is always collected into a list. Every other supported field is collated only if it is present in the first
    example (all examples in a batch are expected to share the same fields):
    - stacked with `torch.stack`: image, prompt_embeds, pooled_prompt_embeds, text_encoder_output, vae_output, mask
    - converted with `torch.tensor`: loss_weight
    - kept as a Python list: original_size_hw, crop_top_left_yx, caption

    Each optional field is handled independently. For example, a batch with `prompt_embeds` but no
    `pooled_prompt_embeds` is collated rather than raising a KeyError.

    Args:
        examples (list[dict]): The examples to collate. Must be non-empty.

    Returns:
        dict: The collated batch, with keys in the order listed above (`id` first).
    """
    # (field name, collation mode) in output-dict order.
    optional_fields = (
        ("image", "stack"),
        ("original_size_hw", "list"),
        ("crop_top_left_yx", "list"),
        ("caption", "list"),
        ("loss_weight", "tensor"),
        ("prompt_embeds", "stack"),
        ("pooled_prompt_embeds", "stack"),
        ("text_encoder_output", "stack"),
        ("vae_output", "stack"),
        ("mask", "stack"),
    )

    out_examples = {
        "id": [example["id"] for example in examples],
    }
    first_example = examples[0]
    for field_name, mode in optional_fields:
        if field_name not in first_example:
            continue
        values = [example[field_name] for example in examples]
        if mode == "stack":
            out_examples[field_name] = torch.stack(values)
        elif mode == "tensor":
            out_examples[field_name] = torch.tensor(values)
        else:
            out_examples[field_name] = values

    return out_examples


def build_aspect_ratio_bucket_manager(config: AspectRatioBucketConfig):
    """Create an AspectRatioBucketManager from the bucketing constraints held in `config`.

    Args:
        config (AspectRatioBucketConfig): Supplies `target_resolution`, `start_dim`, `end_dim` and `divisible_by`.

    Returns:
        AspectRatioBucketManager: The result of `AspectRatioBucketManager.from_constraints(...)`.
    """
    constraints = {
        "target_resolution": config.target_resolution,
        "start_dim": config.start_dim,
        "end_dim": config.end_dim,
        "divisible_by": config.divisible_by,
    }
    return AspectRatioBucketManager.from_constraints(**constraints)


def build_image_caption_sd_dataloader(  # noqa: C901
    config: ImageCaptionSDDataLoaderConfig,
    batch_size: int,
    use_masks: bool = False,
    text_encoder_output_cache_dir: typing.Optional[str] = None,
    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,
    vae_output_cache_dir: typing.Optional[str] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Build a DataLoader over an image-caption dataset for Stable Diffusion training.

    Args:
        config (ImageCaptionSDDataLoaderConfig): The data loader config. `config.dataset` selects the dataset format
            (HF Hub, JSONL, or image-caption directory).
        batch_size (int): The batch size. When `config.aspect_ratio_buckets` is set, it is used by the
            AspectRatioBucketBatchSampler; otherwise it is passed to the DataLoader directly.
        use_masks (bool, optional): If True, the "mask" field is kept and is transformed (or loaded from the VAE cache)
            alongside the image. If False, the "mask" field is dropped.
        text_encoder_output_cache_dir (str, optional): Directory of cached text encoder outputs. If set, the cached
            fields are loaded into each example by id.
        text_encoder_cache_field_to_output_field (dict[str, str], optional): Mapping from text encoder cache field
            names to output example field names. Must be set when `text_encoder_output_cache_dir` is set.
        vae_output_cache_dir (str, optional): Directory of cached VAE outputs. If set, the image augmentation
            transforms are skipped, the raw "image"/"mask" fields are dropped (so the image is never copied to VRAM),
            and the cached VAE output plus size/crop metadata are loaded instead.
        shuffle (bool, optional): Whether to shuffle the dataset order.

    Returns:
        DataLoader
    """
    # Pick the dataset builder matching the dataset config class (first match wins).
    dataset_builders = (
        (HFHubImageCaptionDatasetConfig, build_hf_hub_image_caption_dataset),
        (ImageCaptionJsonlDatasetConfig, build_image_caption_jsonl_dataset),
        (ImageCaptionDirDatasetConfig, build_image_caption_dir_dataset),
    )
    for dataset_config_cls, build_fn in dataset_builders:
        if isinstance(config.dataset, dataset_config_cls):
            base_dataset = build_fn(config.dataset)
            break
    else:
        raise ValueError(f"Unexpected dataset config type: '{type(config.dataset)}'.")

    # Exactly one of `target_resolution` / `aspect_ratio_bucket_manager` ends up non-None.
    bucketing_enabled = config.aspect_ratio_buckets is not None
    target_resolution = None if bucketing_enabled else config.resolution
    aspect_ratio_bucket_manager = None
    batch_sampler = None
    if bucketing_enabled:
        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)
        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.
        batch_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
            bucket_manager=aspect_ratio_bucket_manager,
            image_sizes=base_dataset.get_image_dimensions(),
            batch_size=batch_size,
            shuffle=shuffle,
            seed=0,
        )

    transforms = []

    if config.caption_prefix is not None:
        transforms.append(CaptionPrefixTransform(caption_field_name="caption", prefix=config.caption_prefix + " "))

    if vae_output_cache_dir is not None:
        # Drop the raw PIL fields to avoid converting them or collating PIL images.
        transforms.append(DropFieldTransform("image"))
        transforms.append(DropFieldTransform("mask"))

        cached_vae_fields = {name: name for name in ("vae_output", "original_size_hw", "crop_top_left_yx")}
        if use_masks:
            cached_vae_fields["mask"] = "mask"
        transforms.append(
            LoadCacheTransform(
                cache=TensorDiskCache(vae_output_cache_dir),
                cache_key_field="id",
                cache_field_to_output_field=cached_vae_fields,
            )
        )
    else:
        if use_masks:
            fields_to_transform = ["image", "mask"]
        else:
            fields_to_transform = ["image"]
            transforms.append(DropFieldTransform("mask"))

        transforms.append(
            SDImageTransform(
                image_field_names=fields_to_transform,
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )

    if text_encoder_output_cache_dir is not None:
        assert text_encoder_cache_field_to_output_field is not None
        transforms.append(
            LoadCacheTransform(
                cache=TensorDiskCache(text_encoder_output_cache_dir),
                cache_key_field="id",
                cache_field_to_output_field=text_encoder_cache_field_to_output_field,
            )
        )

    dataset = TransformDataset(base_dataset, transforms)

    # A batch sampler already decides batch composition, so `shuffle`/`batch_size` must not be passed alongside it.
    if batch_sampler is not None:
        return DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            collate_fn=sd_image_caption_collate_fn,
            num_workers=config.dataloader_num_workers,
        )
    return DataLoader(
        dataset,
        shuffle=shuffle,
        collate_fn=sd_image_caption_collate_fn,
        batch_size=batch_size,
        num_workers=config.dataloader_num_workers,
    )


================================================
FILE: src/invoke_training/_shared/data/data_loaders/image_pair_preference_sd_dataloader.py
================================================
import typing

import torch
from torch.utils.data import DataLoader

from invoke_training._shared.data.datasets.build_dataset import build_hf_image_pair_preference_dataset
from invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training.pipelines._experimental.sd_dpo_lora.config import ImagePairPreferenceSDDataLoaderConfig


def sd_image_pair_preference_collate_fn(examples):
    """Collate a list of image-pair preference examples into a single batch dict.

    Tensor fields (images, text encoder outputs, VAE outputs) are stacked with `torch.stack`. All other supported
    fields are collected into Python lists. Only fields present in the first example are collated. The output key
    order is deterministic: stacked fields first, then list fields, each in the order declared below.

    Args:
        examples (list[dict]): The examples to collate. Must be non-empty.

    Returns:
        dict: The collated batch.

    Raises:
        ValueError: If the first example contains a key that this function does not know how to collate.
    """
    # Tuples rather than sets: iterating a set of strings has an order that depends on hash randomization, which
    # would make the output dict's key order vary between runs.
    stack_keys = ("image_0", "image_1", "prompt_embeds", "pooled_prompt_embeds", "text_encoder_output", "vae_output")
    list_keys = (
        "id",
        "original_size_hw_0",
        "original_size_hw_1",
        "crop_top_left_yx_0",
        "crop_top_left_yx_1",
        "prefer_0",
        "prefer_1",
        "caption",
    )

    unhandled_keys = set(examples[0].keys()) - (set(stack_keys) | set(list_keys))
    if len(unhandled_keys) > 0:
        raise ValueError(f"The following keys are not handled by the collate function: {unhandled_keys}.")

    out_examples = {}

    # torch.stack(...)
    for k in stack_keys:
        if k in examples[0]:
            out_examples[k] = torch.stack([example[k] for example in examples])

    # Basic list.
    for k in list_keys:
        if k in examples[0]:
            out_examples[k] = [example[k] for example in examples]

    return out_examples


def build_image_pair_preference_sd_dataloader(
    config: ImagePairPreferenceSDDataLoaderConfig,
    batch_size: int,
    text_encoder_output_cache_dir: typing.Optional[str] = None,
    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,
    vae_output_cache_dir: typing.Optional[str] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Construct a DataLoader for an image-pair preference dataset for Stable Diffusion.

    Each example holds two images ("image_0", "image_1") with per-image preference flags. Each image is processed by
    its own SDImageTransform, which writes "original_size_hw_{0,1}" and "crop_top_left_yx_{0,1}".

    Args:
        config (ImagePairPreferenceSDDataLoaderConfig): The data loader config. `config.dataset.type` selects either
            "HF_HUB_IMAGE_PAIR_PREFERENCE_DATASET" or "IMAGE_PAIR_PREFERENCE_DATASET".
        batch_size (int): The DataLoader batch size.
        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should be
            loaded from.
        text_encoder_cache_field_to_output_field (dict[str, str], optional): Mapping from text encoder cache field
            names to output example field names. Must be set when `text_encoder_output_cache_dir` is set.
        vae_output_cache_dir (str, optional): VAE output caching is not yet supported for this data loader; passing a
            non-None value raises NotImplementedError.
        shuffle (bool, optional): Whether to shuffle the dataset order.

    Returns:
        DataLoader

    Raises:
        ValueError: If `config.dataset.type` is not a supported image-pair preference dataset type.
        NotImplementedError: If `vae_output_cache_dir` is set.
    """
    if config.dataset.type == "HF_HUB_IMAGE_PAIR_PREFERENCE_DATASET":
        base_dataset = build_hf_image_pair_preference_dataset(config=config.dataset)
    elif config.dataset.type == "IMAGE_PAIR_PREFERENCE_DATASET":
        base_dataset = ImagePairPreferenceDataset(dataset_dir=config.dataset.dataset_dir)
    else:
        # Report the `type` discriminator that was actually dispatched on, not just the Python class.
        raise ValueError(
            f"Unexpected dataset config type: '{config.dataset.type}' (config class: {type(config.dataset)})."
        )

    target_resolution = config.resolution

    all_transforms = []
    if vae_output_cache_dir is None:
        # TODO(ryand): Should I process both images in a single SDImageTransform so that they undergo the same
        # transformations?
        all_transforms.append(
            SDImageTransform(
                image_field_names=["image_0"],
                fields_to_normalize_to_range_minus_one_to_one=["image_0"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=None,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
                orig_size_field_name="original_size_hw_0",
                crop_field_name="crop_top_left_yx_0",
            )
        )
        all_transforms.append(
            SDImageTransform(
                image_field_names=["image_1"],
                fields_to_normalize_to_range_minus_one_to_one=["image_1"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=None,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
                orig_size_field_name="original_size_hw_1",
                crop_field_name="crop_top_left_yx_1",
            )
        )
    else:
        # TODO: Implement VAE caching. A LoadCacheTransform would need to load per-image outputs and metadata (e.g.
        # "original_size_hw_0"/"_1", "crop_top_left_yx_0"/"_1"), followed by dropping the raw "image_0"/"image_1"
        # PIL fields so they are not collated.
        raise NotImplementedError("VAE caching is not yet implemented.")

    if text_encoder_output_cache_dir is not None:
        assert text_encoder_cache_field_to_output_field is not None
        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=text_encoder_cache,
                cache_key_field="id",
                cache_field_to_output_field=text_encoder_cache_field_to_output_field,
            )
        )

    dataset = TransformDataset(base_dataset, all_transforms)

    return DataLoader(
        dataset,
        shuffle=shuffle,
        collate_fn=sd_image_pair_preference_collate_fn,
        batch_size=batch_size,
        num_workers=config.dataloader_num_workers,
    )


================================================
FILE: src/invoke_training/_shared/data/data_loaders/textual_inversion_sd_dataloader.py
================================================
from typing import Literal, Optional

from torch.utils.data import DataLoader

from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
    build_aspect_ratio_bucket_manager,
    sd_image_caption_collate_fn,
)
from invoke_training._shared.data.datasets.build_dataset import (
    build_hf_hub_image_caption_dataset,
    build_image_caption_dir_dataset,
    build_image_caption_jsonl_dataset,
)
from invoke_training._shared.data.datasets.image_dir_dataset import ImageDirDataset
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler
from invoke_training._shared.data.transforms.concat_fields_transform import ConcatFieldsTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform
from invoke_training._shared.data.transforms.shuffle_caption_transform import ShuffleCaptionTransform
from invoke_training._shared.data.transforms.template_caption_transform import TemplateCaptionTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig
from invoke_training.config.data.dataset_config import (
    HFHubImageCaptionDatasetConfig,
    ImageCaptionDirDatasetConfig,
    ImageCaptionJsonlDatasetConfig,
    ImageDirDatasetConfig,
)


def get_preset_ti_caption_templates(preset: Literal["object", "style"]) -> list[str]:
    if preset == "object":
        return [
            "a photo of a {}",
            "a rendering of a {}",
            "a cropped photo of the {}",
            "the photo of a {}",
            "a photo of a clean {}",
            "a photo of a dirty {}",
            "a dark photo of the {}",
            "a photo of my {}",
            "a photo of the cool {}",
            "a close-up photo of a {}",
            "a bright photo of the {}",
            "a cropped photo of a {}",
            "a photo of the {}",
            "a good photo of the {}",
            "a photo of one {}",
            "a close-up photo of the {}",
            "a rendition of the {}",
            "a photo of the clean {}",
            "a rendition of a {}",
            "a photo of a nice {}",
            "a good photo of a {}",
            "a photo of the nice {}",
            "a photo of the small {}",
            "a photo of the weird {}",
            "a photo of the large {}",
            "a photo of a cool {}",
            "a photo of a small {}",
        ]
    elif preset == "style":
        return [
            "a painting in the style of {}",
            "a rendering in the style of {}",
            "a cropped painting in the style of {}",
            "the painting in the style of {}",
            "a clean painting in the style of {}",
            "a dirty painting in the style of {}",
            "a dark painting in the style of {}",
            "a picture in the style of {}",
            "a cool painting in the style of {}",
            "a close-up painting in the style of {}",
            "a bright painting in the style of {}",
            "a good painting in the style of {}",
            "a close-up painting in the style of {}",
            "a rendition in the style of {}",
            "a nice painting in the style of {}",
            "a small painting in the style of {}",
            "a weird painting in the style of {}",
            "a large painting in the style of {}",
            "a photo in the style of {}",
            "an image in the style of {}",
            "a drawing in the style of {}",
            "a sketch in the style of {}",
            "a digital work in the style of {}",
            "a digital rendering in the style of {}",
            "a photograph in the style of {}",
            "photography in the style of {}",
        ]
    else:
        raise ValueError(f"Unrecognized learnable property type: '{preset}'.")


def build_textual_inversion_sd_dataloader(  # noqa: C901
    config: TextualInversionSDDataLoaderConfig,
    placeholder_token: str,
    batch_size: int,
    use_masks: bool = False,
    vae_output_cache_dir: Optional[str] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Construct a DataLoader for a Textual Inversion dataset for Stable Diffusion.

    Args:
        config (TextualInversionSDDataLoaderConfig): The dataset config. Exactly one of `config.caption_templates` or
            `config.caption_preset` must be set.
        placeholder_token (str): The placeholder token being trained. It is substituted into the caption templates.
        batch_size (int): The DataLoader batch size.
        use_masks (bool, optional): If True, keep the "mask" field of each example (it is transformed together with
            the image, or loaded from the VAE output cache). If False, the "mask" field is dropped.
        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If
            set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM.
        shuffle (bool, optional): Whether to shuffle the dataset order.

    Raises:
        ValueError: If not exactly one of `config.caption_templates` / `config.caption_preset` is set, if the caption
            preset is unknown, or if the dataset config type is not supported.

    Returns:
        DataLoader
    """
    # Validate the caption config before doing any (potentially slow) dataset loading.
    has_templates = config.caption_templates is not None
    has_preset = config.caption_preset is not None
    if not has_templates and not has_preset:
        raise ValueError("Either caption_templates or caption_preset must be set.")
    if has_templates and has_preset:
        raise ValueError("Only one of caption_templates or caption_preset may be set, not both.")
    caption_templates = (
        config.caption_templates if has_templates else get_preset_ti_caption_templates(config.caption_preset)
    )

    if isinstance(config.dataset, HFHubImageCaptionDatasetConfig):
        base_dataset = build_hf_hub_image_caption_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionJsonlDatasetConfig):
        base_dataset = build_image_caption_jsonl_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionDirDatasetConfig):
        base_dataset = build_image_caption_dir_dataset(config.dataset)
    elif isinstance(config.dataset, ImageDirDatasetConfig):
        base_dataset = ImageDirDataset(
            image_dir=config.dataset.dataset_dir, keep_in_memory=config.dataset.keep_in_memory
        )
    else:
        raise ValueError(f"Unexpected dataset config type: '{type(config.dataset)}'.")

    # Initialize either the fixed target resolution or aspect ratio buckets.
    if config.aspect_ratio_buckets is None:
        target_resolution = config.resolution
        aspect_ratio_bucket_manager = None
        batch_sampler = None
    else:
        target_resolution = None
        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)
        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.
        batch_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
            bucket_manager=aspect_ratio_bucket_manager,
            image_sizes=base_dataset.get_image_dimensions(),
            batch_size=batch_size,
            shuffle=shuffle,
            seed=0,
        )

    # Fill captions from the templates. This overwrites the caption field (typically used with an ImageDirDataset that
    # has no captions), unless keep_original_captions is set, in which case the templated text is written to a
    # separate "caption_prefix" field and prepended to the original caption below.
    all_transforms = [
        TemplateCaptionTransform(
            field_name="caption_prefix" if config.keep_original_captions else "caption",
            placeholder_str=placeholder_token,
            caption_templates=caption_templates,
        )
    ]

    if config.keep_original_captions:
        # This will only work with a dataset that already has captions (e.g. HFHubImageCaptionDataset).
        all_transforms.append(
            ConcatFieldsTransform(
                src_field_names=["caption_prefix", "caption"], dst_field_name="caption", separator=" "
            )
        )

    if config.shuffle_caption_delimiter is not None:
        all_transforms.append(ShuffleCaptionTransform(field_name="caption", delimiter=config.shuffle_caption_delimiter))

    if vae_output_cache_dir is None:
        image_field_names = ["image"]
        if use_masks:
            image_field_names.append("mask")
        else:
            all_transforms.append(DropFieldTransform("mask"))

        all_transforms.append(
            SDImageTransform(
                image_field_names=image_field_names,
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )
    else:
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))
        all_transforms.append(DropFieldTransform("mask"))

        vae_cache = TensorDiskCache(vae_output_cache_dir)

        cache_field_to_output_field = {
            "vae_output": "vae_output",
            "original_size_hw": "original_size_hw",
            "crop_top_left_yx": "crop_top_left_yx",
        }
        if use_masks:
            # The mask was dropped above; re-load the cached (already transformed) version instead.
            cache_field_to_output_field["mask"] = "mask"

        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache,
                cache_key_field="id",
                cache_field_to_output_field=cache_field_to_output_field,
            )
        )

    dataset = TransformDataset(base_dataset, all_transforms)

    common_loader_kwargs = {
        "collate_fn": sd_image_caption_collate_fn,
        "num_workers": config.dataloader_num_workers,
        "persistent_workers": config.dataloader_num_workers > 0,
    }
    if batch_sampler is None:
        return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size, **common_loader_kwargs)

    # DataLoader's batch_sampler is mutually exclusive with batch_size and shuffle; the sampler handles both.
    return DataLoader(dataset, batch_sampler=batch_sampler, **common_loader_kwargs)


================================================
FILE: src/invoke_training/_shared/data/datasets/__init__.py
================================================


================================================
FILE: src/invoke_training/_shared/data/datasets/build_dataset.py
================================================
from datasets import VerificationMode

from invoke_training._shared.data.datasets.hf_image_caption_dataset import HFImageCaptionDataset
from invoke_training._shared.data.datasets.hf_image_pair_preference_dataset import HFImagePairPreferenceDataset
from invoke_training._shared.data.datasets.image_caption_dir_dataset import ImageCaptionDirDataset
from invoke_training._shared.data.datasets.image_caption_jsonl_dataset import ImageCaptionJsonlDataset
from invoke_training.config.data.dataset_config import (
    HFHubImageCaptionDatasetConfig,
    ImageCaptionDirDatasetConfig,
    ImageCaptionJsonlDatasetConfig,
)
from invoke_training.pipelines._experimental.sd_dpo_lora.config import HFHubImagePairPreferenceDatasetConfig


def build_hf_hub_image_caption_dataset(config: HFHubImageCaptionDatasetConfig) -> HFImageCaptionDataset:
    """Load an image-caption dataset from the Hugging Face Hub, as described by `config`.

    The config's dataset config name and cache dir are forwarded to `datasets.load_dataset(...)`.
    """
    load_dataset_kwargs = {"name": config.dataset_config_name, "cache_dir": config.hf_cache_dir}
    return HFImageCaptionDataset.from_hub(
        dataset_name=config.dataset_name,
        hf_load_dataset_kwargs=load_dataset_kwargs,
        image_column=config.image_column,
        caption_column=config.caption_column,
    )


def build_image_caption_jsonl_dataset(config: ImageCaptionJsonlDatasetConfig) -> ImageCaptionJsonlDataset:
    """Create an ImageCaptionJsonlDataset from the .jsonl path, column names and in-memory setting in `config`."""
    return ImageCaptionJsonlDataset(
        jsonl_path=config.jsonl_path,
        image_column=config.image_column,
        caption_column=config.caption_column,
        keep_in_memory=config.keep_in_memory,
    )


def build_image_caption_dir_dataset(config: ImageCaptionDirDatasetConfig) -> ImageCaptionDirDataset:
    """Create an ImageCaptionDirDataset from the directory and in-memory setting in `config`."""
    dataset_dir = config.dataset_dir
    return ImageCaptionDirDataset(dataset_dir, keep_in_memory=config.keep_in_memory)


def build_hf_image_pair_preference_dataset(
    config: HFHubImagePairPreferenceDatasetConfig,
) -> HFImagePairPreferenceDataset:
    """Build the image-pair preference dataset.

    HACK(ryand): `config` is currently ignored. This is hard-coded to download just a small slice (the first 6 train
    shards) of the very large 'yuvalkirstain/pickapic_v2' dataset.
    """
    train_shards = [
        "data/train-00000-of-00645-b66ac786bf6fb553.parquet",
        "data/train-00001-of-00645-c7b349dd222d6515.parquet",
        "data/train-00002-of-00645-e4f54d615a978deb.parquet",
        "data/train-00003-of-00645-2b9d59bac8b433ff.parquet",
        "data/train-00004-of-00645-e4964649dc0ea543.parquet",
        "data/train-00005-of-00645-45e8efc0fe93f6e9.parquet",
    ]
    load_dataset_kwargs = {
        # A validation split is currently not loaded:
        # "validation_unique": "data/validation_unique-00000-of-00001-33ead111845fc9c4.parquet",
        "data_files": {"train": train_shards},
        # Disable checks so that it doesn't complain that the other splits haven't been downloaded.
        "verification_mode": VerificationMode.NO_CHECKS,
    }
    return HFImagePairPreferenceDataset.from_hub(
        "yuvalkirstain/pickapic_v2",
        split="train",
        hf_load_dataset_kwargs=load_dataset_kwargs,
    )


================================================
FILE: src/invoke_training/_shared/data/datasets/hf_image_caption_dataset.py
================================================
import os
import typing

import datasets
import torch.utils.data
from PIL.Image import Image

from invoke_training._shared.data.utils.resolution import Resolution


class HFImageCaptionDataset(torch.utils.data.Dataset):
    """An image-caption dataset wrapper for Hugging Face datasets.

    The wrapped HF dataset can be either from the HF hub, or in Imagefolder format
    (https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder).

    Only the "train" split of the wrapped dataset is used. Whatever the source column names are, the examples produced
    by this class always use the keys "image", "caption" and "id".
    """

    def __init__(self, hf_dataset, image_column: str = "image", caption_column: str = "text"):
        """Initialize a HFImageCaptionDataset.

        Args:
            hf_dataset: A Hugging Face dataset dict (e.g. as returned by `datasets.load_dataset(...)`) that has a
                "train" split.
            image_column (str, optional): The name of the image column in the dataset. Defaults to "image".
            caption_column (str, optional): The name of the caption column in the dataset. Defaults to "text".

        Raises:
            ValueError: If `image_column` or `caption_column` is not a column of the "train" split.
        """
        train_split = hf_dataset["train"]
        column_names = train_split.column_names
        if image_column not in column_names:
            raise ValueError(
                f"The image_column='{image_column}' is not in the set of dataset column names: '{column_names}'."
            )

        if caption_column not in column_names:
            raise ValueError(
                f"The caption_column='{caption_column}' is not in the set of dataset column names: '{column_names}'."
            )

        self._image_column = image_column

        def preprocess(examples):
            # Applied to batches (column name -> list of values). Converts images to RGB (drops alpha channels, expands
            # greyscale) and renames the source columns to the fixed "image" / "caption" output keys.
            images = [image.convert("RGB") for image in examples[image_column]]
            return {
                "image": images,
                "caption": examples[caption_column],
            }

        self._hf_dataset = train_split.with_transform(preprocess)

    @classmethod
    def from_dir(
        cls,
        dataset_dir: str,
        hf_load_dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None,
        image_column: str = "image",
        caption_column: str = "text",
    ):
        """Initialize a HFImageCaptionDataset from a Hugging Face ImageFolder dataset directory
        (https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder).

        Args:
            dataset_dir (str): The path to the dataset directory.
            hf_load_dataset_kwargs (dict[str, typing.Any], optional): kwargs to forward to `datasets.load_dataset(...)`.
            image_column (str, optional): The name of the image column in the dataset. Defaults to "image".
            caption_column (str, optional): The name of the caption column in the dataset. Defaults to "text".
        """
        hf_load_dataset_kwargs = hf_load_dataset_kwargs or {}
        # Every file under dataset_dir (recursively) is assigned to the "train" split.
        data_files = {"train": os.path.join(dataset_dir, "**")}
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
        hf_dataset = datasets.load_dataset("imagefolder", data_files=data_files, **hf_load_dataset_kwargs)

        return cls(hf_dataset=hf_dataset, image_column=image_column, caption_column=caption_column)

    @classmethod
    def from_hub(
        cls,
        dataset_name: str,
        hf_load_dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None,
        image_column: str = "image",
        caption_column: str = "text",
    ):
        """Initialize a HFImageCaptionDataset from a Hugging Face Hub dataset.

        Args:
            dataset_name (str): The HF Hub dataset name (a.k.a. path).
            hf_load_dataset_kwargs (dict[str, typing.Any], optional): kwargs to forward to `datasets.load_dataset(...)`.
            image_column (str, optional): The name of the image column in the dataset. Defaults to "image".
            caption_column (str, optional): The name of the caption column in the dataset. Defaults to "text".
        """
        hf_load_dataset_kwargs = hf_load_dataset_kwargs or {}
        hf_dataset = datasets.load_dataset(dataset_name, **hf_load_dataset_kwargs)

        return cls(hf_dataset=hf_dataset, image_column=image_column, caption_column=caption_column)

    def get_image_dimensions(self) -> list[Resolution]:
        """Get the dimensions of all images in the dataset.

        Examples are read through the preprocessing transform, which renames the source image column to "image", so
        the image must be looked up under that key (not under the original `image_column` name).

        TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to
        calculate this dynamically every time.
        """
        image_dims: list[Resolution] = []
        for i in range(len(self._hf_dataset)):
            image: Image = self._hf_dataset[i]["image"]
            image_dims.append(Resolution(image.height, image.width))

        return image_dims

    def __len__(self) -> int:
        """Get the dataset length.

        Returns:
            int: The number of image-caption pairs in the dataset.
        """
        return len(self._hf_dataset)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        """Load the dataset example at index `idx`.

        Raises:
            IndexError: If `idx` is out of range.

        Returns:
            dict: A dataset example with 3 keys: "image", "caption", and "id".
                The "image" key maps to a `PIL` image in RGB format.
                The "caption" key maps to a string.
                The "id" key is the example's index (often used for caching).
        """
        example = self._hf_dataset[idx]
        example["id"] = idx
        return example


================================================
FILE: src/invoke_training/_shared/data/datasets/hf_image_pair_preference_dataset.py
================================================
import io
import typing

import datasets
import torch.utils.data
from PIL import Image


class HFImagePairPreferenceDataset(torch.utils.data.Dataset):
    """A wrapper for the Hugging Face hub "yuvalkirstain/pickapic_v2" dataset
    (https://huggingface.co/datasets/yuvalkirstain/pickapic_v2).

    Designed to be expanded in the future to other HF image pair preference datasets.
    """

    def __init__(
        self,
        hf_dataset,
        skip_no_preference=True,
        split: str = "train",
        image_0_column: str = "jpg_0",
        label_0_column: str = "label_0",
        image_1_column: str = "jpg_1",
        label_1_column: str = "label_1",
        caption_column: str = "caption",
    ):
        """Initialize a HFImagePairPreferenceDataset.

        Args:
            hf_dataset: A Hugging Face dataset dict (e.g. as returned by `datasets.load_dataset(...)`) containing
                `split`.
            skip_no_preference (bool, optional): If True, skip image pairs without a preference (i.e. ties).
            split (str, optional): The dataset split to use. Defaults to "train".
            image_0_column (str, optional): Column holding the encoded bytes of the first image.
            label_0_column (str, optional): Column holding the numeric preference label of the first image.
            image_1_column (str, optional): Column holding the encoded bytes of the second image.
            label_1_column (str, optional): Column holding the numeric preference label of the second image.
            caption_column (str, optional): Column holding the caption (prompt) shared by both images.

        Raises:
            ValueError: If any of the configured columns is missing from `split`.
        """
        split_dataset = hf_dataset[split]
        column_names = split_dataset.column_names

        for col_name in [image_0_column, label_0_column, image_1_column, label_1_column, caption_column]:
            if col_name not in column_names:
                raise ValueError(f"Column '{col_name}' is not in the set of dataset column names: '{column_names}'.")

        # Labels that differ by no more than eps are treated as a tie (no preference).
        eps = 0.0001

        if skip_no_preference:
            # Filter to only include pairs with a clear preference.
            def has_preference(example: dict[str, typing.Any]) -> bool:
                return abs(example[label_0_column] - example[label_1_column]) > eps

            split_dataset = split_dataset.filter(has_preference)

        def preprocess(examples):
            # Applied to batches (column name -> list of values). The image columns hold encoded image bytes.
            image_0_list = [Image.open(io.BytesIO(image)).convert("RGB") for image in examples[image_0_column]]
            image_1_list = [Image.open(io.BytesIO(image)).convert("RGB") for image in examples[image_1_column]]

            # At most one of the two flags can be True; both are False for a tie.
            image_0_is_better = []
            image_1_is_better = []
            for label_0, label_1 in zip(examples[label_0_column], examples[label_1_column]):
                image_0_is_better.append(bool((label_0 - label_1) > eps))
                image_1_is_better.append(bool((label_1 - label_0) > eps))

            return {
                "image_0": image_0_list,
                "image_1": image_1_list,
                "prefer_0": image_0_is_better,
                "prefer_1": image_1_is_better,
                "caption": examples[caption_column],
            }

        self._hf_dataset = split_dataset.with_transform(preprocess)

    @classmethod
    def from_hub(
        cls,
        dataset_name: str,
        skip_no_preference: bool = True,
        split: str = "train",
        hf_load_dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None,
    ):
        """Initialize a HFImagePairPreferenceDataset from a Hugging Face Hub dataset.

        Args:
            dataset_name (str): The HF Hub dataset name (a.k.a. path). Only "yuvalkirstain/pickapic_v2" is supported.
            skip_no_preference (bool, optional): If True, skip image pairs without a preference.
            split (str, optional): The dataset split to use. Defaults to "train".
            hf_load_dataset_kwargs (dict[str, typing.Any], optional): kwargs to forward to `datasets.load_dataset(...)`.

        Raises:
            NotImplementedError: If `dataset_name` is not "yuvalkirstain/pickapic_v2".
        """
        if dataset_name != "yuvalkirstain/pickapic_v2":
            raise NotImplementedError(
                "The HFImagePairPreferenceDataset class likely won't work with datasets other than "
                "'yuvalkirstain/pickapic_v2'."
            )

        hf_load_dataset_kwargs = hf_load_dataset_kwargs or {}
        hf_dataset = datasets.load_dataset(dataset_name, **hf_load_dataset_kwargs)

        return cls(hf_dataset=hf_dataset, skip_no_preference=skip_no_preference, split=split)

    def __len__(self) -> int:
        """Get the dataset length.

        Returns:
            int: The number of image pairs in the dataset.
        """
        return len(self._hf_dataset)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        """Load the dataset example at index `idx`.

        Raises:
            IndexError: If `idx` is out of range.

        Returns:
            dict: A dataset example with the following keys: ["id", "image_0", "image_1", "prefer_0", "prefer_1",
                "caption"]
                The image keys map to `PIL` images in RGB format.
                The "prefer_0" / "prefer_1" keys map to bools indicating which image is preferred (both False on a tie).
                The "caption" key maps to a string.
                The "id" key is the example's index (often used for caching).
        """
        example = self._hf_dataset[idx]
        example["id"] = idx
        return example


================================================
FILE: src/invoke_training/_shared/data/datasets/image_caption_dir_dataset.py
================================================
import os
import typing

import torch.utils.data
from PIL import Image

from invoke_training._shared.data.utils.resolution import Resolution


class ImageCaptionDirDataset(torch.utils.data.Dataset):
    """A dataset that loads images and captions from a directory of image files and .txt files.

    Every image file must have a caption file with the same base name in the same directory (e.g. `cat.png` ->
    `cat.txt`).
    """

    def __init__(
        self,
        dataset_dir: str,
        id_prefix: str = "",
        image_extensions: typing.Optional[list[str]] = None,
        caption_extension: str = ".txt",
        keep_in_memory: bool = False,
    ):
        """Initialize an ImageCaptionDirDataset.

        Args:
            dataset_dir (str): The directory to load images and captions from. Sub-directories are not searched.
            id_prefix (str): A prefix added to the 'id' field in every example.
            image_extensions (list[str], optional): The list of image file extensions to include in the dataset (not
                case-sensitive). Defaults to [".jpg", ".jpeg", ".png"].
            caption_extension (str, optional): The file extension of the caption files. Defaults to ".txt".
            keep_in_memory (bool, optional): If True, keep all images loaded in memory. This improves performance for
                datasets that are small enough to be kept in memory.

        Raises:
            Exception: If any image does not have a matching caption file.
        """
        super().__init__()
        self._id_prefix = id_prefix
        if image_extensions is None:
            image_extensions = [".jpg", ".jpeg", ".png"]
        image_extensions = [ext.lower() for ext in image_extensions]

        # Determine the list of image paths to include in the dataset. The paths are sorted because os.listdir()
        # returns entries in arbitrary order, and example ids are derived from the position in this list.
        self._image_paths: list[str] = []
        for image_file in os.listdir(dataset_dir):
            image_path = os.path.join(dataset_dir, image_file)
            if os.path.isfile(image_path) and os.path.splitext(image_path)[1].lower() in image_extensions:
                self._image_paths.append(image_path)
        self._image_paths.sort()

        # Load captions from the caption files for each image. Missing files are collected so that they can all be
        # reported in a single error.
        self._captions: list[str] = []
        missing_captions: list[str] = []
        for image_path in self._image_paths:
            caption_path = os.path.splitext(image_path)[0] + caption_extension
            if os.path.isfile(caption_path):
                with open(caption_path, "r") as f:
                    self._captions.append(f.read().strip())
            else:
                missing_captions.append(caption_path)
        if len(missing_captions) > 0:
            raise Exception(f"The following expected caption files are missing: {missing_captions}")

        self._images = None
        if keep_in_memory:
            self._images = []
            for image_path in self._image_paths:
                self._images.append(self._load_image(image_path))

    def _load_image(self, image_path: str) -> Image.Image:
        # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale
        # images. `convert()` returns a new, fully-loaded image, so the source file can be closed immediately.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def get_image_dimensions(self) -> list[Resolution]:
        """Get the dimensions of all images in the dataset.

        TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to
        calculate this dynamically every time.
        """
        image_dims: list[Resolution] = []
        for image_path in self._image_paths:
            # Image.open() is lazy (only the header is read to get the size); the context manager closes the file.
            with Image.open(image_path) as image:
                image_dims.append(Resolution(image.height, image.width))

        return image_dims

    def __len__(self) -> int:
        return len(self._image_paths)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        image = self._images[idx] if self._images is not None else self._load_image(self._image_paths[idx])
        return {"id": f"{self._id_prefix}{idx}", "image": image, "caption": self._captions[idx]}


================================================
FILE: src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py
================================================
import typing
from pathlib import Path

import torch.utils.data
from PIL import Image
from pydantic import BaseModel

from invoke_training._shared.data.utils.resolution import Resolution
from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl

IMAGE_COLUMN_DEFAULT = "image"
CAPTION_COLUMN_DEFAULT = "text"
MASK_COLUMN_DEFAULT = "mask"


class ImageCaptionExample(BaseModel):
    """One image-caption record of a JSONL dataset.

    The paths are stored exactly as read from the JSONL record; they may be absolute or relative to the JSONL file.
    """

    image_path: str
    # Optional path to a mask image for this example; None when the record has no mask.
    mask_path: typing.Optional[str] = None
    caption: str


class ImageCaptionJsonlDataset(torch.utils.data.Dataset):
    """A dataset that loads image-caption examples (with optional masks) from a .jsonl file.

    Each record in the .jsonl file holds an image path under `image_column`, a caption under `caption_column`, and
    optionally a mask path under the "mask" key. Image and mask paths may be absolute, or relative to the directory
    that contains the .jsonl file.
    """

    def __init__(
        self,
        jsonl_path: Path | str,
        image_column: str = IMAGE_COLUMN_DEFAULT,
        caption_column: str = CAPTION_COLUMN_DEFAULT,
        keep_in_memory: bool = False,
    ):
        """Initialize an ImageCaptionJsonlDataset.

        Args:
            jsonl_path (Path | str): Path to the .jsonl file.
            image_column (str, optional): Key of the image path in each record. Defaults to "image".
            caption_column (str, optional): Key of the caption in each record. Defaults to "text".
            keep_in_memory (bool, optional): If True, cache each example the first time it is loaded, so that later
                accesses do not re-read the image (and mask) from disk.

        Raises:
            ValueError: If a record is missing `image_column` or `caption_column`.
        """
        super().__init__()
        self._jsonl_path = Path(jsonl_path)
        self._image_column = image_column
        self._caption_column = caption_column

        data = load_jsonl(jsonl_path)
        examples: list[ImageCaptionExample] = []
        for d in data:
            # Clear error messages here are helpful in the Gradio UI.
            if image_column not in d:
                raise ValueError(f"Column '{image_column}' not found in jsonl file '{jsonl_path}'.")
            if caption_column not in d:
                raise ValueError(f"Column '{caption_column}' not found in jsonl file '{jsonl_path}'.")
            examples.append(
                ImageCaptionExample(
                    image_path=d[image_column], mask_path=d.get(MASK_COLUMN_DEFAULT, None), caption=d[caption_column]
                )
            )
        self.examples = examples

        self._keep_in_memory = keep_in_memory
        # Lazily populated cache: example index -> loaded example. Only used when keep_in_memory is True.
        self._example_cache: dict[int, dict[str, typing.Any]] = {}

    def save_jsonl(self):
        """Write the (possibly modified) examples back to the .jsonl file this dataset was loaded from."""
        data = []
        for example in self.examples:
            data.append(
                {
                    self._image_column: example.image_path,
                    self._caption_column: example.caption,
                    MASK_COLUMN_DEFAULT: example.mask_path,
                }
            )
        save_jsonl(data, self._jsonl_path)

    def _resolve_path(self, path: str) -> Path:
        """Return `path` unchanged if absolute, otherwise resolve it relative to the .jsonl file's directory."""
        resolved = Path(path)
        if not resolved.is_absolute():
            resolved = self._jsonl_path.parent / resolved
        return resolved

    def _get_image_path(self, idx: int) -> Path:
        return self._resolve_path(self.examples[idx].image_path)

    def _get_mask_path(self, idx: int) -> Path:
        # Only call this for examples that have a mask_path (Path(None) raises a TypeError).
        return self._resolve_path(self.examples[idx].mask_path)

    def _load_image(self, image_path: Path | str) -> Image.Image:
        # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale
        # images. `convert()` returns a new, fully-loaded image, so the source file can be closed immediately.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def _load_mask(self, mask_path: Path | str) -> Image.Image:
        # Masks are loaded as single-channel ("L" mode) images.
        with Image.open(mask_path) as mask:
            return mask.convert("L")

    def _load_example(self, idx: int) -> dict[str, typing.Any]:
        example = {
            "id": str(idx),
            "image": self._load_image(self._get_image_path(idx)),
            "caption": self.examples[idx].caption,
        }
        if self.examples[idx].mask_path:
            example["mask"] = self._load_mask(self._get_mask_path(idx))
        return example

    def get_image_dimensions(self) -> list[Resolution]:
        """Get the dimensions of all images in the dataset.

        TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to
        calculate this dynamically every time.
        """
        image_dims: list[Resolution] = []
        for i in range(len(self.examples)):
            # Image.open() is lazy (only the header is read to get the size); the context manager closes the file.
            with Image.open(self._get_image_path(i)) as image:
                image_dims.append(Resolution(image.height, image.width))

        return image_dims

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        if self._keep_in_memory:
            if idx not in self._example_cache:
                self._example_cache[idx] = self._load_example(idx)
            # Return a shallow copy of the example to prevent the caller from modifying the cached example.
            # Shallow rather than deep, because we don't want to copy the image data.
            return self._example_cache[idx].copy()
        return self._load_example(idx)


================================================
FILE: src/invoke_training/_shared/data/datasets/image_dir_dataset.py
================================================
import os
import typing

import torch.utils.data
from PIL import Image

from invoke_training._shared.data.utils.resolution import Resolution


class ImageDirDataset(torch.utils.data.Dataset):
    """A dataset that loads image files from a directory."""

    def __init__(
        self,
        image_dir: str,
        id_prefix: str = "",
        image_extensions: typing.Optional[list[str]] = None,
        keep_in_memory: bool = False,
    ):
        """Initialize an ImageDirDataset

        Args:
            image_dir (str): The directory to load images from. Only regular files directly inside this directory are
                included; sub-directories are not searched.
            id_prefix (str): A prefix added to the 'id' field in every example.
            image_extensions (list[str], optional): The list of image file extensions to include in the dataset (not
                case-sensitive). Defaults to [".jpg", ".jpeg", ".png"].
            keep_in_memory (bool, optional): If True, keep all images loaded in memory. This improves performance for
                datasets that are small enough to be kept in memory.
        """
        super().__init__()
        self._id_prefix = id_prefix
        if image_extensions is None:
            image_extensions = [".jpg", ".jpeg", ".png"]
        # Lower-case the extensions so that matching is case-insensitive.
        allowed_extensions = {ext.lower() for ext in image_extensions}

        # os.listdir() returns entries in an arbitrary, filesystem-dependent order. Example ids are derived from the
        # index of each image, so sort the listing to keep ids stable across runs and machines.
        self._image_paths: list[str] = []
        for image_file in sorted(os.listdir(image_dir)):
            image_path = os.path.join(image_dir, image_file)
            if os.path.isfile(image_path) and os.path.splitext(image_path)[1].lower() in allowed_extensions:
                self._image_paths.append(image_path)

        self._images = None
        if keep_in_memory:
            self._images = [self._load_image(image_path) for image_path in self._image_paths]

    def _load_image(self, image_path: str) -> Image.Image:
        # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale
        # images. The `with` block releases the file handle once the converted copy has been created.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def get_image_dimensions(self) -> list[Resolution]:
        """Get the dimensions of all images in the dataset.

        Only image headers are read. Each file is closed immediately so that large directories do not accumulate open
        file handles.

        TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to
        calculate this dynamically every time.

        Returns:
            list[Resolution]: One (height, width) resolution per image, in dataset order.
        """
        image_dims: list[Resolution] = []
        for image_path in self._image_paths:
            with Image.open(image_path) as image:
                image_dims.append(Resolution(image.height, image.width))

        return image_dims

    def __len__(self) -> int:
        return len(self._image_paths)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        """Return {"id": f"{id_prefix}{idx}", "image": <RGB PIL image>} for image `idx`."""
        image = self._images[idx] if self._images is not None else self._load_image(self._image_paths[idx])
        return {"id": f"{self._id_prefix}{idx}", "image": image}


================================================
FILE: src/invoke_training/_shared/data/datasets/image_pair_preference_dataset.py
================================================
import os
import typing
from pathlib import Path

import torch.utils.data
from PIL import Image

from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl


class ImagePairPreferenceDataset(torch.utils.data.Dataset):
    """A dataset of prompt + image-pair examples with preference labels, indexed by a `metadata.jsonl` file.

    Each metadata record must provide the keys "image_0" and "image_1" (image paths relative to `dataset_dir`),
    "prompt", "prefer_0" and "prefer_1".
    """

    def __init__(self, dataset_dir: str):
        """Initialize the dataset from `<dataset_dir>/metadata.jsonl`.

        Args:
            dataset_dir (str): Directory containing `metadata.jsonl` and the images it references.
        """
        super().__init__()
        self._dataset_dir = dataset_dir

        self._metadata = load_jsonl(Path(dataset_dir) / "metadata.jsonl")

    @classmethod
    def save_metadata(
        cls, metadata: list[dict[str, typing.Any]], dataset_dir: str | Path, metadata_file: str = "metadata.jsonl"
    ) -> Path:
        """Save dataset metadata records to `<dataset_dir>/<metadata_file>` (via `save_jsonl`).

        Args:
            metadata (list[dict[str, typing.Any]]): The metadata records to write, one per example.
            dataset_dir (str | Path): The dataset directory.
            metadata_file (str, optional): The metadata file name. Defaults to "metadata.jsonl", which is the file that
                `__init__` reads.

        Returns:
            Path: The path of the written metadata file.
        """
        metadata_path = Path(dataset_dir) / metadata_file
        save_jsonl(metadata, metadata_path)
        return metadata_path

    def __len__(self) -> int:
        return len(self._metadata)

    def _load_image(self, relative_path: str) -> Image.Image:
        """Load the image at `relative_path` (relative to the dataset directory) as an RGB image.

        `convert("RGB")` drops the alpha channel from RGBA images and repeats channels for greyscale images. The file is
        opened in a `with` block so that its handle is released once the converted copy exists.
        """
        with Image.open(os.path.join(self._dataset_dir, relative_path)) as image:
            return image.convert("RGB")

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        example = self._metadata[idx]
        return {
            "id": str(idx),
            "image_0": self._load_image(example["image_0"]),
            "image_1": self._load_image(example["image_1"]),
            "caption": example["prompt"],
            "prefer_0": example["prefer_0"],
            "prefer_1": example["prefer_1"],
        }


================================================
FILE: src/invoke_training/_shared/data/datasets/transform_dataset.py
================================================
import typing

import torch.utils.data

# The data type expected to be produced by the base dataset and handled by transforms.
DataType = typing.Dict[str, typing.Any]

TransformType = typing.Callable[[DataType], DataType]


class TransformDataset(torch.utils.data.Dataset):
    """Dataset wrapper that pipes every example of a base dataset through a sequence of transforms."""

    def __init__(self, base_dataset: torch.utils.data.Dataset, transforms: list[TransformType]) -> None:
        """Wrap `base_dataset`.

        Args:
            base_dataset: The dataset whose examples are transformed.
            transforms: Callables applied in list order; each receives the output of the previous one.
        """
        super().__init__()
        self._base_dataset = base_dataset
        self._transforms = transforms

    def __len__(self) -> int:
        return len(self._base_dataset)

    def __getitem__(self, idx: int) -> DataType:
        result = self._base_dataset[idx]
        for transform in self._transforms:
            result = transform(result)
        return result


================================================
FILE: src/invoke_training/_shared/data/samplers/__init__.py
================================================


================================================
FILE: src/invoke_training/_shared/data/samplers/aspect_ratio_bucket_batch_sampler.py
================================================
import copy
import logging
import math
import random
from typing import Iterator

from torch.utils.data import Sampler

from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager
from invoke_training._shared.data.utils.resolution import Resolution

# Maps each aspect ratio bucket resolution to the list of dataset indices assigned to that bucket.
AspectRatioBuckets = dict[Resolution, list[int]]


class AspectRatioBucketBatchSampler(Sampler[list[int]]):
    """A batch sampler whose batches never mix dataset indices from different aspect ratio buckets."""

    def __init__(
        self,
        buckets: AspectRatioBuckets,
        batch_size: int,
        shuffle: bool = False,
        seed: int | None = None,
    ) -> None:
        """Initialize AspectRatioBucketBatchSampler.

        Most callers should construct instances with AspectRatioBucketBatchSampler.from_image_sizes(...).

        Args:
            buckets: Mapping from bucket resolution to the dataset indices assigned to that bucket.
            batch_size: Maximum number of indices per batch. The last batch of each bucket may be smaller.
            shuffle: If True, shuffle the indices within each bucket and then the order of all batches, on every
                iteration.
            seed: Seed for this sampler's private `random.Random` instance.
        """
        self._buckets = buckets
        self._batch_size = batch_size
        self._shuffle = shuffle
        self._random = random.Random(seed)

    def __str__(self) -> str:
        # One line per bucket, in sorted resolution order: "  (<resolution tuple>): <number of images>".
        buckets = self.get_buckets()
        return "".join(f"  {resolution.to_tuple()}: {len(buckets[resolution])}\n" for resolution in sorted(buckets))

    @classmethod
    def from_image_sizes(
        cls,
        bucket_manager: AspectRatioBucketManager,
        image_sizes: list[Resolution],
        batch_size: int,
        shuffle: bool = False,
        seed: int | None = None,
    ):
        """Build a sampler by assigning each image (by its index in `image_sizes`) to its aspect ratio bucket."""
        bucket_map = cls._build_bucket_to_index_map(bucket_manager, image_sizes)
        return cls(buckets=bucket_map, batch_size=batch_size, shuffle=shuffle, seed=seed)

    @classmethod
    def _build_bucket_to_index_map(
        cls,
        bucket_manager: AspectRatioBucketManager,
        image_sizes: list[Resolution],
    ) -> AspectRatioBuckets:
        # Every bucket known to the manager gets an entry, even if no image ends up in it.
        bucket_map: AspectRatioBuckets = {resolution: [] for resolution in bucket_manager.buckets}
        for image_index, size in enumerate(image_sizes):
            bucket_map[bucket_manager.get_aspect_ratio_bucket(size)].append(image_index)
        return bucket_map

    def get_buckets(self) -> AspectRatioBuckets:
        """Return a deep copy of the bucket map so that callers cannot mutate the sampler's state."""
        return copy.deepcopy(self._buckets)

    def __iter__(self) -> Iterator[list[int]]:
        # TODO(ryand): If self._shuffle == False, should we still shuffle just with a fixed seed every time? If we
        # don't shuffle at all then all of the batches from a bucket will be grouped together. If there's a correlation
        # between aspect ratio and image content in a dataset, this could result in unevenly distributed image content
        # over the dataset.
        all_batches: list[list[int]] = []
        for resolution in sorted(self._buckets):
            indices = list(self._buckets[resolution])
            if self._shuffle:
                # Shuffle within the bucket first...
                self._random.shuffle(indices)

            start = 0
            while start < len(indices):
                all_batches.append(indices[start : start + self._batch_size])
                start += self._batch_size

        if self._shuffle:
            # ...then shuffle the order of the (single-bucket) batches.
            self._random.shuffle(all_batches)

        yield from all_batches

    def __len__(self) -> int:
        # Each bucket contributes ceil(bucket_size / batch_size) batches, matching __iter__.
        return sum(math.ceil(len(indices) / self._batch_size) for indices in self._buckets.values())


def log_aspect_ratio_buckets(logger: logging.Logger, batch_sampler: AspectRatioBucketBatchSampler):
    """Log the per-bucket image counts of `batch_sampler` at INFO level.

    This is a no-op for samplers that are not an AspectRatioBucketBatchSampler.
    """
    if isinstance(batch_sampler, AspectRatioBucketBatchSampler):
        logger.info(f"Aspect Ratio Buckets:\n{batch_sampler}")


================================================
FILE: src/invoke_training/_shared/data/samplers/batch_offset_sampler.py
================================================
import typing

from torch.utils.data import Sampler


class BatchOffsetSampler(Sampler[list[int]]):
    """A sampler that wraps a batch sampler and adds a fixed offset to every index of every batch it returns."""

    def __init__(self, sampler: Sampler[list[int]], offset: int):
        """Initialize BatchOffsetSampler.

        Args:
            sampler: The wrapped batch sampler. Each item it yields must be an iterable of integer indices.
            offset: The value added to every index.
        """
        self._sampler = sampler
        self._offset = offset

    def __iter__(self) -> typing.Iterator[list[int]]:
        for batch in self._sampler:
            yield [x + self._offset for x in batch]

    def __len__(self) -> int:
        # The number of batches is unchanged by the offset.
        return len(self._sampler)


================================================
FILE: src/invoke_training/_shared/data/samplers/concat_sampler.py
================================================
import itertools
import typing

from torch.utils.data import Sampler

T_co = typing.TypeVar("T_co", covariant=True)


class ConcatSampler(Sampler[T_co]):
    """A meta-Sampler that concatenates multiple samplers.

    Example:
        sampler 1:           ABCD
        sampler 2:           EFG
        sampler 3:           HIJKLM
        ConcatSampler:       ABCDEFGHIJKLM
    """

    def __init__(self, samplers: list[Sampler[T_co] | typing.Iterable[T_co]]) -> None:
        self._samplers = samplers

    def __iter__(self) -> typing.Iterator[T_co]:
        return itertools.chain(*self._samplers)

    def __len__(self) -> int:
        return sum([len(s) for s in self._samplers])


================================================
FILE: src/invoke_training/_shared/data/samplers/interleaved_sampler.py
================================================
import typing

from torch.utils.data import Sampler

T_co = typing.TypeVar("T_co", covariant=True)


class InterleavedSampler(Sampler[T_co]):
    """A meta-Sampler that interleaves multiple samplers.

    The length of this sampler is based on the length of the shortest input sampler. All samplers will contribute the
    same number of samples to the interleaved output.

    Example:
        sampler 1:           ABCD
        sampler 2:           EFG
        sampler 3:           HIJKLM
        interleaved sampler: AEHBFICGJ
    """

    def __init__(self, samplers: list[Sampler[T_co] | typing.Iterable[T_co]]) -> None:
        self._samplers = samplers
        self._min_sampler_len = min([len(s) for s in self._samplers])

    def __iter__(self) -> typing.Iterator[T_co]:
        sampler_iters = [iter(s) for s in self._samplers]
        while True:
            samples = []
            for sampler_iter in sampler_iters:
                try:
                    samples.append(next(sampler_iter))
                except StopIteration:
                    # The end of the shortest sampler has been reached.
                    return

            yield from samples

    def __len__(self) -> int:
        return self._min_sampler_len * len(self._samplers)


================================================
FILE: src/invoke_training/_shared/data/samplers/offset_sampler.py
================================================
import typing

from torch.utils.data import Sampler


class OffsetSampler(Sampler[int]):
    """A sampler wrapper that shifts every index produced by the wrapped sampler by a fixed `offset`."""

    def __init__(self, sampler: Sampler[int], offset: int):
        self._sampler = sampler
        self._offset = offset

    def __iter__(self) -> typing.Iterator[int]:
        yield from (index + self._offset for index in self._sampler)

    def __len__(self) -> int:
        # Shifting indices does not change how many there are.
        return len(self._sampler)


================================================
FILE: src/invoke_training/_shared/data/transforms/__init__.py
================================================


================================================
FILE: src/invoke_training/_shared/data/transforms/caption_prefix_transform.py
================================================
import typing


class CaptionPrefixTransform:
    """A transform that prepends a fixed prefix to the caption field of each example (modifying it in place)."""

    def __init__(self, caption_field_name: str, prefix: str):
        self._caption_field_name = caption_field_name
        self._prefix = prefix

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        field = self._caption_field_name
        original_caption = data[field]
        data[field] = self._prefix + original_caption
        return data


================================================
FILE: src/invoke_training/_shared/data/transforms/concat_fields_transform.py
================================================
import typing


class ConcatFieldsTransform:
    """A transform that joins several string fields with a separator and stores the result in a destination field."""

    def __init__(self, src_field_names: list[str], dst_field_name: str, separator: str = " "):
        self._src_field_names = src_field_names
        self._dst_field_name = dst_field_name
        self._separator = separator

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        pieces = (data[name] for name in self._src_field_names)
        data[self._dst_field_name] = self._separator.join(pieces)
        return data


================================================
FILE: src/invoke_training/_shared/data/transforms/constant_field_transform.py
================================================
import typing


class ConstantFieldTransform:
    """A transform that sets the same (constant) value on a given field of every example, in place."""

    def __init__(self, field_name: str, field_value: typing.Any):
        self._field_name = field_name
        self._field_value = field_value

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        name, value = self._field_name, self._field_value
        data[name] = value
        return data


================================================
FILE: src/invoke_training/_shared/data/transforms/drop_field_transform.py
================================================
import typing


class DropFieldTransform:
    """A transform that removes one field from each example, if present. Examples lacking the field pass through."""

    def __init__(self, field_to_drop: str):
        self._field_to_drop = field_to_drop

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        data.pop(self._field_to_drop, None)
        return data


================================================
FILE: src/invoke_training/_shared/data/transforms/flux_image_transform.py
================================================
import typing

from torchvision import transforms
from torchvision.transforms.functional import crop

from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager, Resolution
from invoke_training._shared.data.utils.resize import resize_to_cover


class FluxImageTransform:
    """A transform that prepares and augments images for Flux.1-dev training.

    All fields in `image_field_names` are resized, cropped and (optionally) flipped with the same parameters, so that
    multiple image fields of one example stay spatially aligned with each other.
    """

    def __init__(
        self,
        image_field_names: list[str],
        fields_to_normalize_to_range_minus_one_to_one: list[str],
        resolution: int | None = 512,
        aspect_ratio_bucket_manager: AspectRatioBucketManager | None = None,
        random_flip: bool = True,
        center_crop: bool = True,
    ):
        """Initialize FluxImageTransform.

        Args:
            image_field_names (list[str]): The field names of the images to be transformed. The first field's image
                determines the original size, the target resolution and the crop/flip parameters for all fields.
            fields_to_normalize_to_range_minus_one_to_one (list[str]): The image fields whose tensor values are mapped
                from [0, 1] to [-1, 1] after conversion to tensors.
            resolution (int, optional): The (square) image resolution that will be produced. If set, it takes
                precedence over `aspect_ratio_bucket_manager`.
            aspect_ratio_bucket_manager (AspectRatioBucketManager, optional): The AspectRatioBucketManager used to
                determine the target resolution for each image when `resolution` is None.
            random_flip (bool, optional): Whether to apply a random horizontal flip to the images.
            center_crop (bool, optional): If True, crop to the center of the image to achieve the target resolution. If
                False, crop at a random location.

        Raises:
            ValueError: If both `resolution` and `aspect_ratio_bucket_manager` are None.
        """
        if resolution is None and aspect_ratio_bucket_manager is None:
            raise ValueError("One of `resolution` or `aspect_ratio_bucket_manager` must be set.")

        self.image_field_names = image_field_names
        self.fields_to_normalize_to_range_minus_one_to_one = fields_to_normalize_to_range_minus_one_to_one
        self.resolution = resolution
        self.aspect_ratio_bucket_manager = aspect_ratio_bucket_manager
        self.random_flip = random_flip
        self.center_crop = center_crop

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:  # noqa: C901
        # `random` is not imported at the top of this module, so import it locally (once per call, not per field).
        import random

        image_fields: dict = {field_name: data[field_name] for field_name in self.image_field_names}

        def get_first_image():
            return next(iter(image_fields.values()))

        # The first image determines the original size and therefore the target resolution for all fields.
        first_image = get_first_image()
        original_size_hw = (first_image.height, first_image.width)

        # Determine the target image resolution.
        if self.resolution is not None:
            resolution_obj = Resolution(self.resolution, self.resolution)
        else:
            resolution_obj = self.aspect_ratio_bucket_manager.get_aspect_ratio_bucket(
                Resolution.parse(original_size_hw)
            )

        # Resize every field to cover the target resolution while preserving aspect ratio.
        for field_name, image in image_fields.items():
            image_fields[field_name] = resize_to_cover(image, resolution_obj)

        # Choose ONE crop window (from the first resized image) and apply it to every field. The offsets are applied
        # explicitly with `crop()` so that the recorded `crop_top_left_yx` is exactly the crop that was applied.
        if self.center_crop:
            top_left_y = max(0, (get_first_image().height - resolution_obj.height) // 2)
            top_left_x = max(0, (get_first_image().width - resolution_obj.width) // 2)
        else:
            top_left_y, top_left_x, _, _ = transforms.RandomCrop.get_params(
                get_first_image(), resolution_obj.to_tuple()
            )
        for field_name, image in image_fields.items():
            image_fields[field_name] = crop(image, top_left_y, top_left_x, resolution_obj.height, resolution_obj.width)

        # Decide once whether to flip, so that all fields are flipped together, and update the top left crop position
        # accordingly.
        # TODO: Use a seed for repeatable results
        if self.random_flip and random.random() < 0.5:
            top_left_x = original_size_hw[1] - get_first_image().width - top_left_x
            flip = transforms.RandomHorizontalFlip(p=1.0)
            for field_name, image in image_fields.items():
                image_fields[field_name] = flip(image)

        # Convert to tensors, normalizing the selected fields from [0, 1] to [-1, 1], and store them in the example.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize([0.5], [0.5])
        for field_name, image in image_fields.items():
            tensor = to_tensor(image)
            if field_name in self.fields_to_normalize_to_range_minus_one_to_one:
                tensor = normalize(tensor)
            data[field_name] = tensor

        # Add metadata fields expected by VAE caching
        data["original_size_hw"] = original_size_hw
        data["crop_top_left_yx"] = (top_left_y, top_left_x)

        return data


================================================
FILE: src/invoke_training/_shared/data/transforms/load_cache_transform.py
================================================
import typing

from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache


class LoadCacheTransform:
    """A transform that copies entries loaded from a TensorDiskCache into each example."""

    def __init__(
        self, cache: TensorDiskCache, cache_key_field: str, cache_field_to_output_field: typing.Dict[str, str]
    ):
        """Initialize LoadCacheTransform.

        Args:
            cache (TensorDiskCache): The cache to load from.
            cache_key_field (str): The example field whose value is used as the cache key.
            cache_field_to_output_field (typing.Dict[str, str]): Maps each field name in the cached data to the example
                field it is written to.
        """
        self._cache = cache
        self._cache_key_field = cache_key_field
        self._cache_field_to_output_field = cache_field_to_output_field

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        cached = self._cache.load(data[self._cache_key_field])
        for cache_field, output_field in self._cache_field_to_output_field.items():
            data[output_field] = cached[cache_field]
        return data


================================================
FILE: src/invoke_training/_shared/data/transforms/sd_image_transform.py
================================================
import random
import typing

from torchvision import transforms
from torchvision.transforms.functional import crop

from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager, Resolution
from invoke_training._shared.data.utils.resize import resize_to_cover


class SDImageTransform:
    """A transform that prepares and augments images for Stable Diffusion training.

    Every field in `image_field_names` is resized, cropped and (optionally) flipped with the same parameters and then
    converted to a tensor. The original image size and the top-left crop position are written to the example under
    `orig_size_field_name` and `crop_field_name`.
    """

    def __init__(
        self,
        image_field_names: list[str],
        fields_to_normalize_to_range_minus_one_to_one: list[str],
        resolution: int | tuple[int, int] | Resolution | None,
        aspect_ratio_bucket_manager: AspectRatioBucketManager | None = None,
        center_crop: bool = True,
        random_flip: bool = False,
        orig_size_field_name: str = "original_size_hw",
        crop_field_name: str = "crop_top_left_yx",
    ):
        """Initialize SDImageTransform.

        Args:
            image_field_names (list[str]): The field names of the images to be transformed. All of these images must
                have the same size.
            fields_to_normalize_to_range_minus_one_to_one (list[str]): The image fields whose tensor values are mapped
                from [0, 1] to [-1, 1].
            resolution (Resolution): The image resolution that will be produced (parsed with `Resolution.parse`). One
                of `resolution` and `aspect_ratio_bucket_manager` should be non-None.
            aspect_ratio_bucket_manager (AspectRatioBucketManager): The AspectRatioBucketManager used to determine the
                target resolution for each image. One of `resolution` and `aspect_ratio_bucket_manager` should be
                non-None.
            center_crop (bool, optional): If True, crop to the center of the image to achieve the target resolution. If
                False, crop at a random location.
            random_flip (bool, optional): Whether to apply a random horizontal flip to the images.
            orig_size_field_name (str, optional): Output field that receives the original (height, width) of the images.
            crop_field_name (str, optional): Output field that receives the (top, left) crop position, adjusted for any
                horizontal flip.

        Raises:
            ValueError: If both or neither of `resolution` and `aspect_ratio_bucket_manager` are set.
        """
        self._image_field_names = image_field_names
        self._fields_to_normalize_to_range_minus_one_to_one = fields_to_normalize_to_range_minus_one_to_one
        if resolution is not None and aspect_ratio_bucket_manager is not None:
            raise ValueError("Only one of `resolution` or `aspect_ratio_bucket_manager` should be set.")

        if resolution is None and aspect_ratio_bucket_manager is None:
            raise ValueError("One of `resolution` or `aspect_ratio_bucket_manager` must be set.")

        self._resolution = Resolution.parse(resolution) if resolution is not None else None
        self._aspect_ratio_bucket_manager = aspect_ratio_bucket_manager
        self._center_crop_enabled = center_crop
        self._random_flip_enabled = random_flip
        # p=1.0: the flip/no-flip decision is made in __call__ so that all image fields are flipped together.
        self._flip_transform = transforms.RandomHorizontalFlip(p=1.0)
        self._to_tensor_transform = transforms.ToTensor()
        # Convert pixel values from range [0, 1.0] to range [-1.0, 1.0].
        # Normalize applies the following transform: out = (in - 0.5) / 0.5
        self._normalize_image_transform = transforms.Normalize([0.5], [0.5])

        self._orig_size_field_name = orig_size_field_name
        self._crop_field_name = crop_field_name

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:  # noqa: C901
        """Transform the image fields of `data` in place and record size/crop metadata. Returns `data`."""
        # This SDXL image pre-processing logic is adapted from:
        # https://github.com/huggingface/diffusers/blob/7b07f9812a58bfa96c06ed8ffe9e6b584286e2fd/examples/text_to_image/train_text_to_image_lora_sdxl.py#L850-L873

        image_fields: dict = {}
        for field_name in self._image_field_names:
            image_fields[field_name] = data[field_name]
        sizes = [image.size for image in image_fields.values()]
        # All images should have the same size.
        assert all(size == sizes[0] for size in sizes)

        # Helper function to access the first image, which is sometimes used to infer the shape of all images.
        def get_first_image():
            return next(iter(image_fields.values()))

        original_size_hw = (get_first_image().height, get_first_image().width)

        # Determine the target image resolution.
        if self._resolution is not None:
            resolution = self._resolution
        else:
            resolution = self._aspect_ratio_bucket_manager.get_aspect_ratio_bucket(Resolution.parse(original_size_hw))

        # Resize to cover the target resolution while preserving aspect ratio.
        for field_name, image in image_fields.items():
            image_fields[field_name] = resize_to_cover(image, resolution)

        # Apply cropping, and record top left crop position.
        # A single crop window (computed from the first image) is applied to every field.
        if self._center_crop_enabled:
            top_left_y = max(0, (get_first_image().height - resolution.height) // 2)
            top_left_x = max(0, (get_first_image().width - resolution.width) // 2)
        else:
            crop_transform = transforms.RandomCrop(resolution.to_tuple())
            top_left_y, top_left_x, h, w = crop_transform.get_params(get_first_image(), resolution.to_tuple())
        for field_name, image in image_fields.items():
            image_fields[field_name] = crop(image, top_left_y, top_left_x, resolution.height, resolution.width)

        # Apply random flip and update top left crop position accordingly.
        # TODO(ryand): Use a seed for repeatable results.
        # random.random() is only drawn when flipping is enabled (short-circuit `and`).
        if self._random_flip_enabled and random.random() < 0.5:
            top_left_x = original_size_hw[1] - get_first_image().width - top_left_x
            for field_name, image in image_fields.items():
                image_fields[field_name] = self._flip_transform(image)

        crop_top_left_yx = (top_left_y, top_left_x)

        # Convert to Tensors.
        for field_name, image in image_fields.items():
            image_fields[field_name] = self._to_tensor_transform(image)

        # Normalize to range [-1.0, 1.0].
        # HACK(ryand): We should find a better way to determine the normalization range of each image field.
        for field_name, image in image_fields.items():
            if field_name in self._fields_to_normalize_to_range_minus_one_to_one:
                image_fields[field_name] = self._normalize_image_transform(image)

        # Write metadata and the transformed images back into the example dict (in place).
        data[self._orig_size_field_name] = original_size_hw
        data[self._crop_field_name] = crop_top_left_yx
        for field_name, image in image_fields.items():
            data[field_name] = image

        return data


================================================
FILE: src/invoke_training/_shared/data/transforms/shuffle_caption_transform.py
================================================
import typing

import numpy as np


class ShuffleCaptionTransform:
    """A transform that applies shuffle transformations to character-delimited captions.

    Whitespace around each chunk is stripped, and empty chunks (e.g. from a trailing or repeated delimiter) are
    dropped, so the shuffled caption never contains dangling separators. Chunks are re-joined with `delimiter + " "`.

    Example:
    - Original: "unreal engine, render of sci-fi helmet, dramatic lighting"
    - Shuffled: "render of sci-fi helmet, unreal engine, dramatic lighting"
    """

    def __init__(self, field_name: str, delimiter: str = ",", seed: int = 0):
        """Initialize ShuffleCaptionTransform.

        Args:
            field_name (str): The caption field to shuffle.
            delimiter (str, optional): The string that separates caption chunks.
            seed (int, optional): Seed for the transform's numpy random generator.
        """
        self._field_name = field_name
        self._delimiter = delimiter
        self._rng = np.random.default_rng(seed)

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        caption: str = data[self._field_name]
        stripped_chunks = [chunk.strip() for chunk in caption.split(self._delimiter)]
        # Drop empty chunks; otherwise "a, b," would shuffle into e.g. ", a, b" or "a, , b".
        caption_chunks = [chunk for chunk in stripped_chunks if chunk]

        self._rng.shuffle(caption_chunks)

        join_str = self._delimiter + " "
        data[self._field_name] = join_str.join(caption_chunks)
        return data


================================================
FILE: src/invoke_training/_shared/data/transforms/template_caption_transform.py
================================================
import typing

import numpy as np


class TemplateCaptionTransform:
    """A transform that builds each example's caption by filling a randomly chosen caption template with the
    placeholder string.
    """

    def __init__(self, field_name: str, placeholder_str: str, caption_templates: list[str], seed: int = 0):
        """Initialize TemplateCaptionTransform.

        Args:
            field_name (str): The field that receives the generated caption.
            placeholder_str (str): The string substituted into the chosen template via `str.format`.
            caption_templates (list[str]): Candidate templates; one is chosen at random per example.
            seed (int, optional): Seed for the transform's numpy random generator.
        """
        self._field_name = field_name
        self._placeholder_str = placeholder_str
        self._caption_templates = caption_templates
        self._rng = np.random.default_rng(seed)

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        template = self._rng.choice(self._caption_templates)
        caption = template.format(self._placeholder_str)
        # A well-formed template must leave the placeholder string in the resulting caption.
        assert self._placeholder_str in caption

        data[self._field_name] = caption
        return data


================================================
FILE: src/invoke_training/_shared/data/transforms/tensor_disk_cache.py
================================================
import os
import typing

import torch


class TensorDiskCache:
    """A data cache that caches `torch.Tensor`s on disk.

    Each entry is a dict saved with `torch.save()` to `<cache_dir>/<key>.pt`.
    """

    def __init__(self, cache_dir: str):
        """
        Args:
            cache_dir (str): Directory in which cache files are stored. It is created (including parents) if missing.
        """
        super().__init__()
        self._cache_dir = cache_dir

        os.makedirs(self._cache_dir, exist_ok=True)

    def _get_path(self, key: int) -> str:
        """Get the cache file path for `key`.

        Args:
            key (int): The cache key.

        Returns:
            str: The cache file path.
        """
        return os.path.join(self._cache_dir, f"{key}.pt")

    def save(self, key: int, data: typing.Dict[str, torch.Tensor]):
        """Save data in the cache.

        The data is first written to a temporary file in the cache directory. It is then renamed to its final path
        with `os.replace`, so an interrupted save never leaves a partially written `<key>.pt` behind.

        Args:
            key (int): The cache key.
            data (typing.Dict[str, torch.Tensor]): The data to save.

        Raises:
            AssertionError: If `data` is not a dict, or if an entry already exists in the cache for this `key`.
        """
        import tempfile

        # torch.save() supports a range of different data types, but it is cleaner if we force everyone to use a dict.
        # This allows for more reusable cache loading code.
        assert isinstance(data, dict)

        save_path = self._get_path(key)
        assert not os.path.exists(save_path)

        # Writing directly to `save_path` risks a crash mid-write. That would leave a truncated `<key>.pt` that
        # `load()` cannot read and that makes every retry of `save()` fail the existence check above. The temp name
        # starts with "." so it can never collide with a `<key>.pt` entry.
        fd, tmp_path = tempfile.mkstemp(dir=self._cache_dir, prefix=f".{key}.", suffix=".tmp")
        os.close(fd)
        try:
            torch.save(data, tmp_path)
            os.replace(tmp_path, save_path)
        finally:
            # On success `tmp_path` has already been renamed away; this only cleans up after a failure.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def load(self, key: int) -> typing.Dict[str, torch.Tensor]:
        """Load data from the cache.

        Args:
            key (int): The cache key to load.

        Returns:
            typing.Dict[str, torch.Tensor]: Data loaded from the cache.
        """
        return torch.load(self._get_path(key))


================================================
FILE: src/invoke_training/_shared/data/utils/__init__.py
================================================


================================================
FILE: src/invoke_training/_shared/data/utils/aspect_ratio_bucket_manager.py
================================================
from invoke_training._shared.data.utils.resolution import Resolution


class AspectRatioBucketManager:
    """Holds a set of aspect-ratio bucket resolutions and maps any resolution to the bucket with the closest ratio."""

    def __init__(self, buckets: set[Resolution]):
        """
        Args:
            buckets (set[Resolution]): The bucket resolutions to choose from.
        """
        self.buckets = buckets

    @classmethod
    def from_constraints(
        cls, target_resolution: int, start_dim: int, end_dim: int, divisible_by: int
    ) -> "AspectRatioBucketManager":
        """Construct a manager whose buckets are generated by `build_aspect_ratio_buckets()`.

        See `build_aspect_ratio_buckets()` for the meaning of the arguments.
        """
        buckets = cls.build_aspect_ratio_buckets(
            target_resolution=target_resolution,
            start_dim=start_dim,
            end_dim=end_dim,
            divisible_by=divisible_by,
        )
        return cls(buckets)

    @classmethod
    def build_aspect_ratio_buckets(
        cls, target_resolution: int, start_dim: int, end_dim: int, divisible_by: int
    ) -> set[Resolution]:
        """Prepare a set of aspect ratio bucket resolutions.

        Heights step from `start_dim` to `end_dim` (inclusive) in increments of `divisible_by`. For each height, the
        width is the largest multiple of `divisible_by` such that `height * width <= target_resolution**2`. Both
        `(height, width)` and its transpose `(width, height)` are added. Generation stops early once the computed width
        would be 0, since such a bucket has no valid aspect ratio.

        Args:
            target_resolution (int): All resolutions in the returned set will aim to have close to
                (but <=) `target_resolution * target_resolution` pixels.
            start_dim (int): The smallest height (inclusive) for which a bucket is generated.
            end_dim (int): The largest height (inclusive) for which a bucket is generated.
            divisible_by (int): All dimensions in the returned set of resolutions will be divisible by `divisible_by`.

        Returns:
            set[Resolution]: The aspect ratio bucket resolutions.

        Raises:
            AssertionError: If `divisible_by` is not positive, if `target_resolution`, `start_dim` or `end_dim` is not
                divisible by `divisible_by`, or if `start_dim > end_dim`.
        """
        # A non-positive step would make the height loop below never terminate.
        assert divisible_by > 0

        # Validate target_resolution.
        assert target_resolution % divisible_by == 0

        # Validate start_dim, end_dim.
        assert start_dim <= end_dim
        assert start_dim % divisible_by == 0
        assert end_dim % divisible_by == 0

        target_size = target_resolution * target_resolution

        buckets = set()

        height = start_dim
        while height <= end_dim:
            width = (target_size // height) // divisible_by * divisible_by
            if width == 0:
                # Width is non-increasing in height, so every taller bucket would also be degenerate. A zero-width
                # bucket would make `Resolution.aspect_ratio()` (height / width) divide by zero during lookup.
                break
            buckets.add(Resolution(height, width))
            buckets.add(Resolution(width, height))

            height += divisible_by

        return buckets

    def get_aspect_ratio_bucket(self, resolution: Resolution) -> Resolution:
        """Get the bucket with the closest aspect ratio to 'resolution'.

        Raises:
            ValueError: If there are no buckets (raised by `min()` on an empty set).
        """
        # Note: If this is ever found to be a bottleneck, there is a clearly-more-efficient implementation using bisect.
        target_aspect_ratio = resolution.aspect_ratio()
        return min(self.buckets, key=lambda bucket: abs(bucket.aspect_ratio() - target_aspect_ratio))


================================================
FILE: src/invoke_training/_shared/data/utils/resize.py
================================================
import math

from PIL.Image import Image
from torchvision import transforms

from invoke_training._shared.data.utils.resolution import Resolution


def resize_to_cover(image: Image, size_to_cover: Resolution) -> Image:
    """Resize `image` to the smallest size that fully covers `size_to_cover`, preserving the image's aspect ratio.

    The result satisfies:
    - resized_height >= size_to_cover.height
    - resized_width >= size_to_cover.width
    - resized_height == size_to_cover.height or resized_width == size_to_cover.width
    - 'image' aspect ratio is preserved (the overshooting side is rounded up).

    Resizing uses bilinear interpolation.
    """
    height_scale = size_to_cover.height / image.height
    width_scale = size_to_cover.width / image.width

    if height_scale <= width_scale:
        # Width needs the larger scale: match it exactly and let the height overshoot (rounded up).
        new_size_hw = (math.ceil(image.height * width_scale), size_to_cover.width)
    else:
        # Height needs the larger scale: match it exactly and let the width overshoot (rounded up).
        new_size_hw = (size_to_cover.height, math.ceil(image.width * height_scale))

    resizer = transforms.Resize(new_size_hw, interpolation=transforms.InterpolationMode.BILINEAR)
    return resizer(image)


================================================
FILE: src/invoke_training/_shared/data/utils/resolution.py
================================================
from typing import Union


class Resolution:
    """An image resolution stored as (height, width). Instances are hashable and order by (height, width)."""

    def __init__(self, height: int, width: int):
        self.height = height
        self.width = width

    @classmethod
    def parse(cls, resolution: Union[int, tuple[int, int], list[int], "Resolution"]) -> "Resolution":
        """Initialize a Resolution object from another type.

        Args:
            resolution: One of:
                - an int `n`, interpreted as the square resolution `(n, n)`;
                - a `(height, width)` tuple, or an equivalent 2-element list (e.g. as deserialized from a config);
                - another `Resolution`, which is copied.

        Returns:
            Resolution: A new instance.

        Raises:
            ValueError: If `resolution` has an unsupported type, or is a tuple/list that does not have exactly 2 items.
        """
        if isinstance(resolution, int):
            # Assume square resolution.
            return cls(resolution, resolution)
        if isinstance(resolution, (tuple, list)):
            height, width = resolution
            return cls(height, width)
        if isinstance(resolution, cls):
            return cls(resolution.height, resolution.width)
        raise ValueError(f"Unsupported resolution type: '{type(resolution)}'.")

    def aspect_ratio(self) -> float:
        """Return height / width."""
        return self.height / self.width

    def to_tuple(self) -> tuple[int, int]:
        return (self.height, self.width)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Resolution):
            # Defer to Python's default handling (so e.g. `res == None` is False) instead of raising AttributeError.
            return NotImplemented
        return self.to_tuple() == other.to_tuple()

    def __lt__(self, other: "Resolution") -> bool:
        if not isinstance(other, Resolution):
            return NotImplemented
        return self.to_tuple() < other.to_tuple()

    def __hash__(self):
        return hash(self.to_tuple())

    def __repr__(self) -> str:
        return f"Resolution(height={self.height}, width={self.width})"


================================================
FILE: src/invoke_training/_shared/flux/encoding_utils.py
================================================
import logging
from typing import List, Optional, Tuple, Union

import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast


def get_clip_prompt_embeds(
    prompt: Union[str, List[str]],
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    device: torch.device,
    num_images_per_prompt: int = 1,
    tokenizer_max_length: int = 77,
    logger: logging.Logger | None = None,
) -> torch.FloatTensor:
    """Encodes the prompt using CLIP text encoder and returns pooled embeddings.

    Args:
        prompt: A single prompt or a list of prompts.
        tokenizer: Tokenizer matching `text_encoder`.
        text_encoder: The CLIP text encoder.
        device: Device on which the encoder is run and the output is placed.
        num_images_per_prompt: Number of copies of each prompt's embedding to emit.
        tokenizer_max_length: Prompts are padded/truncated to this many tokens.
        logger: If provided, a warning is logged when any prompt is truncated.

    Returns:
        torch.FloatTensor: A 2-D tensor with `len(prompt) * num_images_per_prompt` rows, cast to `text_encoder.dtype`.
            The copies of each prompt's embedding are adjacent.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    # Process text input with the tokenizer
    text_input_ids = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer_max_length,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    ).input_ids

    # Re-tokenizing without truncation is only needed to report what was cut off, so skip it when there is no logger.
    if logger is not None:
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
            logger.warning(f"Warning: The following part of your input was truncated: {removed_text}")

    # Use the pooled output of the CLIP text model.
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False).pooler_output
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # Duplicate each prompt's embedding `num_images_per_prompt` times (copies end up in adjacent rows).
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
    return prompt_embeds.view(batch_size * num_images_per_prompt, -1)


def get_t5_prompt_embeds(
    prompt: Union[str, List[str]],
    tokenizer: T5TokenizerFast,
    text_encoder: T5EncoderModel,
    device: torch.device,
    num_images_per_prompt: int = 1,
    tokenizer_max_length: int = 512,
    logger: logging.Logger | None = None,
) -> torch.FloatTensor:
    """Encodes the prompt using T5 text encoder.

    Args:
        prompt: A single prompt or a list of prompts.
        tokenizer: Tokenizer matching `text_encoder`.
        text_encoder: The T5 encoder.
        device: Device on which the encoder is run and the output is placed.
        num_images_per_prompt: Number of copies of each prompt's embedding sequence to emit.
        tokenizer_max_length: Prompts are padded/truncated to this many tokens.
        logger: If provided, a warning is logged when any prompt is truncated.

    Returns:
        torch.FloatTensor: The encoder's first output, cast to `text_encoder.dtype`. Its shape is
            `(len(prompt) * num_images_per_prompt, seq_len, -1)`, where `seq_len` is the output's sequence length.
            The copies of each prompt's sequence are adjacent.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    # Process text input with the tokenizer
    text_input_ids = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer_max_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    ).input_ids

    # Re-tokenizing without truncation is only needed to report what was cut off, so skip it when there is no logger.
    if logger is not None:
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
            logger.warning(f"Warning: The following part of your input was truncated: {removed_text}")

    # Get prompt embeddings through the text encoder
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # Duplicate each prompt's sequence `num_images_per_prompt` times (copies end up adjacent along dim 0).
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    return prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)


def handle_lora_scale(
    clip_text_encoder: CLIPTextModel,
    t5_text_encoder: T5EncoderModel,
    lora_scale: Optional[float] = None,
    use_peft_backend: bool = False,
):
    """Scale the LoRA layers of the text encoders by `lora_scale` when the PEFT backend is in use.

    Args:
        clip_text_encoder: CLIP text encoder, or None to skip it.
        t5_text_encoder: T5 text encoder, or None to skip it.
        lora_scale: Scale passed to `peft.utils.scale_lora_layers`. If None, nothing is scaled.
        use_peft_backend: Scaling is only applied when this is True.

    Returns:
        bool: True if scaling was applied (so it should later be undone with `reset_lora_scale`), otherwise False.
    """
    if lora_scale is None or not use_peft_backend:
        return False

    from peft.utils import scale_lora_layers

    # Scale whichever encoders are present (CLIP first, then T5).
    for encoder in (clip_text_encoder, t5_text_encoder):
        if encoder is not None:
            scale_lora_layers(encoder, lora_scale)
    return True


def reset_lora_scale(
    clip_text_encoder: CLIPTextModel,
    t5_text_encoder: T5EncoderModel,
    lora_scale: Optional[float] = None,
    lora_applied: bool = False,
    use_peft_backend: bool = False,
):
    """Undo the LoRA scaling applied by `handle_lora_scale`.

    Args:
        clip_text_encoder: CLIP text encoder, or None to skip it.
        t5_text_encoder: T5 text encoder, or None to skip it.
        lora_scale: The same scale that was applied; passed to `peft.utils.unscale_lora_layers`.
        lora_applied: The value returned by `handle_lora_scale`; nothing is done unless this is True.
        use_peft_backend: Unscaling is only performed when this is True.
    """
    if not (lora_applied and use_peft_backend):
        return

    from peft.utils import unscale_lora_layers

    # Unscale whichever encoders are present (CLIP first, then T5).
    for encoder in (clip_text_encoder, t5_text_encoder):
        if encoder is not None:
            unscale_lora_layers(encoder, lora_scale)


# A lot of this code was adapted from:
# https://github.com/huggingface/diffusers/blob/ea81a4228d8ff16042c3ccaf61f0e588e60166cd/src/diffusers/pipelines/flux/pipeline_flux.py#L310-L387
def encode_prompt(
    prompt: Union[str, List[str]],
    prompt_2: Optional[Union[str, List[str]]],
    clip_tokenizer: CLIPTokenizer,
    t5_tokenizer: T5TokenizerFast,
    clip_text_encoder: CLIPTextModel,
    t5_text_encoder: T5EncoderModel,
    device: torch.device,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
    use_peft_backend: bool = False,
    clip_tokenizer_max_length: int = 77,
    t5_tokenizer_max_length: int = 512,
    logger: logging.Logger | None = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """
    Encodes the prompt using both CLIP and T5 text encoders.

    Args:
        prompt: Prompt(s) for the CLIP text encoder.
        prompt_2: Prompt(s) for the T5 text encoder. Falls back to `prompt` when falsy (e.g. None).
        clip_tokenizer, t5_tokenizer: Tokenizers for the two text encoders.
        clip_text_encoder, t5_text_encoder: The text encoders.
        device: Device on which the encoders are run and the outputs are placed.
        num_images_per_prompt: Number of copies of each prompt's embeddings to emit.
        prompt_embeds: Pre-computed T5 embeddings. If given, no text encoding is performed and `pooled_prompt_embeds`
            is returned unchanged.
        pooled_prompt_embeds: Pre-computed CLIP pooled embeddings (only used together with `prompt_embeds`).
        lora_scale, use_peft_backend: Forwarded to `handle_lora_scale` / `reset_lora_scale`.
        clip_tokenizer_max_length, t5_tokenizer_max_length: Max token lengths for the two tokenizers.
        logger: If provided, truncation warnings from both tokenizers are logged here.

    Returns:
        Tuple containing:
            - T5 text embeddings
            - CLIP pooled embeddings
            - Text IDs: zeros of shape `(prompt_embeds.shape[1], 3)`
    """
    # Apply LoRA scale if needed
    lora_applied = handle_lora_scale(
        clip_text_encoder=clip_text_encoder,
        t5_text_encoder=t5_text_encoder,
        lora_scale=lora_scale,
        use_peft_backend=use_peft_backend,
    )

    try:
        # If no pre-generated embeddings, create them
        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # Get CLIP pooled embeddings
            pooled_prompt_embeds = get_clip_prompt_embeds(
                prompt=prompt,
                tokenizer=clip_tokenizer,
                text_encoder=clip_text_encoder,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                tokenizer_max_length=clip_tokenizer_max_length,
                logger=logger,
            )

            # Get T5 text embeddings
            prompt_embeds = get_t5_prompt_embeds(
                prompt=prompt_2,
                tokenizer=t5_tokenizer,
                text_encoder=t5_text_encoder,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                tokenizer_max_length=t5_tokenizer_max_length,
                logger=logger,
            )
    finally:
        # Always undo the LoRA scaling, even if encoding raised, so the text encoders are not left scaled.
        reset_lora_scale(
            clip_text_encoder=clip_text_encoder,
            t5_text_encoder=t5_text_encoder,
            lora_scale=lora_scale,
            lora_applied=lora_applied,
            use_peft_backend=use_peft_backend,
        )

    # Create text_ids placeholder for model
    dtype = clip_text_encoder.dtype if clip_text_encoder is not None else t5_text_encoder.dtype
    text_ids = torch.zeros(prompt_embeds.shape[1], 3, device=device, dtype=dtype)

    return prompt_embeds, pooled_prompt_embeds, text_ids


================================================
FILE: src/invoke_training/_shared/flux/lora_checkpoint_utils.py
================================================
# ruff: noqa: N806
import os
from pathlib import Path

import peft
import torch
from diffusers import FluxTransformer2DModel
from transformers import CLIPTextModel

from invoke_training._shared.checkpoints.lora_checkpoint_utils import (
    _convert_peft_state_dict_to_kohya_state_dict,
    load_multi_model_peft_checkpoint,
    save_multi_model_peft_checkpoint,
)
from invoke_training._shared.checkpoints.serialization import save_state_dict

# LoRA target module names for the Flux transformer.
# NOTE(review): the "single blocks" attention entries ("attn.to_k", "attn.to_q", "attn.to_v") repeat the "double
# blocks" entries. This looks harmless if the consumer treats the list as a set of name patterns; confirm.
FLUX_TRANSFORMER_TARGET_MODULES = [
    # double blocks
    "attn.add_k_proj",
    "attn.add_q_proj",
    "attn.add_v_proj",
    "attn.to_add_out",
    "attn.to_k",
    "attn.to_q",
    "attn.to_v",
    "attn.to_out.0",
    "ff.net.0.proj",
    "ff.net.2.0",
    "ff_context.net.0.proj",
    "ff_context.net.2.0",
    # single blocks
    "attn.to_k",
    "attn.to_q",
    "attn.to_v",
    "proj_mlp",
    "proj_out",
    "proj_in",
]

# LoRA target module names for the text encoder(s).
TEXT_ENCODER_TARGET_MODULES = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"]

# Module lists copied from diffusers training script.
# These module lists will produce lighter, less expressive, LoRA models than the non-light versions.
FLUX_TRANSFORMER_TARGET_MODULES_LIGHT = ["to_k", "to_q", "to_v", "to_out.0"]
FLUX_TEXT_ENCODER_TARGET_MODULES_LIGHT = ["q_proj", "k_proj", "v_proj", "out_proj"]

# Names under which each model is stored in a multi-model PEFT checkpoint directory (one sub-directory per model).
# text_encoder_1 is the CLIP encoder and text_encoder_2 is the T5 encoder (cf. the Kohya prefixes below).
FLUX_PEFT_TRANSFORMER_KEY = "transformer"
FLUX_PEFT_TEXT_ENCODER_1_KEY = "text_encoder_1"
FLUX_PEFT_TEXT_ENCODER_2_KEY = "text_encoder_2"

# Key prefixes used for each model when converting to a Kohya-format LoRA state dict.
FLUX_KOHYA_TRANSFORMER_KEY = "lora_unet"
FLUX_KOHYA_TEXT_ENCODER_1_KEY = "lora_clip"
FLUX_KOHYA_TEXT_ENCODER_2_KEY = "lora_t5"

# Maps each PEFT checkpoint sub-directory name to its Kohya key prefix.
FLUX_PEFT_TO_KOHYA_KEYS = {
    FLUX_PEFT_TRANSFORMER_KEY: FLUX_KOHYA_TRANSFORMER_KEY,
    FLUX_PEFT_TEXT_ENCODER_1_KEY: FLUX_KOHYA_TEXT_ENCODER_1_KEY,
    FLUX_PEFT_TEXT_ENCODER_2_KEY: FLUX_KOHYA_TEXT_ENCODER_2_KEY,
}


def save_flux_peft_checkpoint(
    checkpoint_dir: Path | str,
    transformer: peft.PeftModel | None,
    text_encoder_1: peft.PeftModel | None,
    text_encoder_2: peft.PeftModel | None,
):
    """Save the given Flux PEFT models to `checkpoint_dir` as a multi-model PEFT checkpoint.

    Models that are None are skipped. Each saved model is keyed by its `FLUX_PEFT_*_KEY` name.
    """
    named_models = (
        (FLUX_PEFT_TRANSFORMER_KEY, transformer),
        (FLUX_PEFT_TEXT_ENCODER_1_KEY, text_encoder_1),
        (FLUX_PEFT_TEXT_ENCODER_2_KEY, text_encoder_2),
    )
    models = {key: model for key, model in named_models if model is not None}

    save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=models)


def load_flux_peft_checkpoint(
    checkpoint_dir: Path | str,
    transformer: FluxTransformer2DModel,
    text_encoder_1: CLIPTextModel,
    text_encoder_2: torch.nn.Module,
    is_trainable: bool = False,
):
    """Load a multi-model Flux PEFT checkpoint onto the given base models.

    Args:
        checkpoint_dir: Directory containing one sub-directory per model, named by the `FLUX_PEFT_*_KEY` constants.
        transformer: The base Flux transformer.
        text_encoder_1: The base CLIP text encoder.
        text_encoder_2: The base T5 text encoder (its Kohya prefix is "lora_t5"). It is annotated as
            `torch.nn.Module` because `T5EncoderModel` is not imported in this module.
        is_trainable: Forwarded to `load_multi_model_peft_checkpoint`.

    Returns:
        Tuple of (transformer, text_encoder_1, text_encoder_2) as returned by `load_multi_model_peft_checkpoint`.
        Missing sub-directories do not raise (`raise_if_subdir_missing=False`). What is returned for a missing model
        is decided by that helper.
    """
    models = load_multi_model_peft_checkpoint(
        checkpoint_dir=checkpoint_dir,
        models={
            FLUX_PEFT_TRANSFORMER_KEY: transformer,
            FLUX_PEFT_TEXT_ENCODER_1_KEY: text_encoder_1,
            FLUX_PEFT_TEXT_ENCODER_2_KEY: text_encoder_2,
        },
        is_trainable=is_trainable,
        raise_if_subdir_missing=False,
    )

    return models[FLUX_PEFT_TRANSFORMER_KEY], models[FLUX_PEFT_TEXT_ENCODER_1_KEY], models[FLUX_PEFT_TEXT_ENCODER_2_KEY]


def save_flux_kohya_checkpoint(
    checkpoint_path: Path,
    transformer: peft.PeftModel | None,
    text_encoder_1: peft.PeftModel | None,
    text_encoder_2: peft.PeftModel | None,
):
    """Save the given Flux PEFT models as a single Kohya-format LoRA state dict file.

    Each non-None model is converted under its Kohya prefix:
    - transformer: "lora_unet"
    - text_encoder_1 (CLIP): "lora_clip"
    - text_encoder_2 (T5): "lora_t5"

    This matches the prefixes used by `convert_flux_peft_checkpoint_to_kohya_state_dict`. The T5 encoder was
    previously accepted but silently dropped.

    Args:
        checkpoint_path: Output file path (`str` is also accepted). Parent directories are created if needed.
        transformer, text_encoder_1, text_encoder_2: PEFT-wrapped models, or None to skip.
    """
    checkpoint_path = Path(checkpoint_path)

    kohya_prefixes = []
    models = []
    for kohya_prefix, peft_model in zip(
        [FLUX_KOHYA_TRANSFORMER_KEY, FLUX_KOHYA_TEXT_ENCODER_1_KEY, FLUX_KOHYA_TEXT_ENCODER_2_KEY],
        [transformer, text_encoder_1, text_encoder_2],
        strict=True,
    ):
        if peft_model is not None:
            kohya_prefixes.append(kohya_prefix)
            models.append(peft_model)

    kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models)

    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    save_state_dict(kohya_state_dict, checkpoint_path)


def convert_flux_peft_checkpoint_to_kohya_state_dict(
    in_checkpoint_dir: Path,
    out_checkpoint_file: Path,
    dtype: torch.dtype = torch.float32,
) -> dict[str, torch.Tensor]:
    """Convert Flux PEFT models to a Kohya-format LoRA state dict and save it to `out_checkpoint_file`.

    Args:
        in_checkpoint_dir: Directory whose immediate sub-directories are PEFT models named by the `FLUX_PEFT_*_KEY`
            constants. `str` is also accepted.
        out_checkpoint_file: Path of the Kohya-format file to write (`str` is also accepted). Parent directories are
            created if needed.
        dtype: dtype forwarded to the Kohya state-dict conversion.

    Returns:
        dict[str, torch.Tensor]: The Kohya-format state dict that was written to `out_checkpoint_file`.

    Raises:
        ValueError: If `in_checkpoint_dir` has no sub-directories, or contains an unrecognized sub-directory.
    """
    in_checkpoint_dir = Path(in_checkpoint_dir)
    out_checkpoint_file = Path(out_checkpoint_file)

    # Get the immediate subdirectories of the checkpoint directory. We assume that each subdirectory is a PEFT model.
    # Sorting makes the conversion (and the key order of the output) independent of filesystem listing order.
    peft_model_dirs = sorted(in_checkpoint_dir / d for d in os.listdir(in_checkpoint_dir))
    peft_model_dirs = [d for d in peft_model_dirs if d.is_dir()]  # Filter out non-directories.

    if len(peft_model_dirs) == 0:
        raise ValueError(f"No checkpoint files found in directory '{in_checkpoint_dir}'.")

    kohya_state_dict = {}
    for peft_model_dir in peft_model_dirs:
        if peft_model_dir.name in FLUX_PEFT_TO_KOHYA_KEYS:
            kohya_prefix = FLUX_PEFT_TO_KOHYA_KEYS[peft_model_dir.name]
        else:
            raise ValueError(f"Unrecognized checkpoint directory: '{peft_model_dir}'.")

        # Note: This logic to load the LoraConfig and weights directly is based on how it is done here:
        # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/src/peft/peft_model.py#L672-L689
        # This may need to be updated in the future to support other adapter types (LoKr, LoHa, etc.).
        # Also, I could see this interface breaking in the future.
        lora_config = peft.LoraConfig.from_pretrained(peft_model_dir)
        lora_weights = peft.utils.load_peft_weights(peft_model_dir, device="cpu")

        if kohya_prefix == FLUX_KOHYA_TRANSFORMER_KEY:
            # Match `_convert_peft_models_to_kohya_state_dict`, which remaps the transformer's PEFT keys before the
            # Kohya conversion. Without this, the two export paths would produce different transformer key names.
            lora_weights = convert_diffusers_to_flux_transformer_checkpoint(lora_weights)

        kohya_state_dict.update(
            _convert_peft_state_dict_to_kohya_state_dict(
                lora_config=lora_config, peft_state_dict=lora_weights, prefix=kohya_prefix, dtype=dtype
            )
        )

    out_checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
    save_state_dict(kohya_state_dict, out_checkpoint_file)
    return kohya_state_dict


def _convert_peft_models_to_kohya_state_dict(
    kohya_prefixes: list[str], models: list[peft.PeftModel]
) -> dict[str, torch.Tensor]:
    """Merge the Kohya-format state dicts of several PEFT LoRA models into a single dict.

    Args:
        kohya_prefixes: Kohya key prefix for each model, aligned index-by-index with `models`.
        models: PEFT models whose "default" adapter config is a `peft.LoraConfig`.

    Returns:
        dict[str, torch.Tensor]: The combined Kohya-format state dict (converted with `dtype=torch.float32`).

    Raises:
        ValueError: If `kohya_prefixes` and `models` have different lengths (from `zip(..., strict=True)`).
        AssertionError: If a model's "default" adapter config is not a `peft.LoraConfig`.
    """
    adapter = "default"
    merged = {}

    for prefix, model in zip(kohya_prefixes, models, strict=True):
        config = model.peft_config[adapter]
        assert isinstance(config, peft.LoraConfig)

        peft_state_dict = peft.get_peft_model_state_dict(model, adapter_name=adapter)
        # Only the transformer's keys need remapping before the Kohya conversion.
        if prefix == FLUX_KOHYA_TRANSFORMER_KEY:
            peft_state_dict = convert_diffusers_to_flux_transformer_checkpoint(peft_state_dict)

        converted = _convert_peft_state_dict_to_kohya_state_dict(
            lora_config=config,
            peft_state_dict=peft_state_dict,
            prefix=prefix,
            dtype=torch.float32,
        )
        merged.update(converted)

    return merged


def find_matching_key_prefix(state_dict, key_pattern):
    """
    Check whether any key of `state_dict` contains the base prefix of `key_pattern`.

    The base prefix is `key_pattern` cut at the first ".lora_A", then at the first ".lora_B", then at the first
    ".weight". Matching is by substring, not by exact key.

    Args:
        state_dict: The state dictionary to search in
        key_pattern: The pattern to look for in keys

    Returns:
        The base prefix if some key contains it, False otherwise
    """
    prefix = key_pattern
    for marker in (".lora_A", ".lora_B", ".weight"):
        prefix = prefix.partition(marker)[0]

    if any(prefix in candidate for candidate in state_dict.keys()):
        return prefix
    return False


def convert_layer_weights(target_dict, source_dict, source_pattern, target_pattern):
    """
    Convert weights from source pattern to target pattern if they exist.

    Args:
        target_dict: Dictionary to store converted weights
        source_dict: Source dictionary containing weights
        source_pattern: Original key pattern to search for
        target_pattern: New key pattern to use


    Returns:
        Tuple of (updated target_dict, updated source_dict)
    """
    if original_key := find_matching_key_prefix(source_dict, source_pattern):
        # Find all keys matching the pattern
        keys_to_convert = [k for k in source_dict.keys() if original_key in k]

        for 
Download .txt
gitextract_5kosyax7/

├── .github/
│   └── workflows/
│       ├── deploy.yaml
│       └── test.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── docs/
│   ├── contributing/
│   │   ├── development_environment.md
│   │   ├── directory_structure.md
│   │   ├── documentation.md
│   │   └── tests.md
│   ├── get-started/
│   │   ├── installation.md
│   │   └── quick-start.md
│   ├── guides/
│   │   ├── dataset_formats.md
│   │   ├── model_merge.md
│   │   └── stable_diffusion/
│   │       ├── dpo_lora_sd.md
│   │       ├── gnome_lora_masks_sdxl.md
│   │       ├── robocats_finetune_sdxl.md
│   │       └── textual_inversion_sdxl.md
│   ├── index.md
│   ├── reference/
│   │   └── config/
│   │       ├── index.md
│   │       ├── pipelines/
│   │       │   ├── sd_lora.md
│   │       │   ├── sd_textual_inversion.md
│   │       │   ├── sdxl_finetune.md
│   │       │   ├── sdxl_lora.md
│   │       │   ├── sdxl_lora_and_textual_inversion.md
│   │       │   └── sdxl_textual_inversion.md
│   │       └── shared/
│   │           ├── data/
│   │           │   ├── data_loader_config.md
│   │           │   └── dataset_config.md
│   │           └── optimizer_config.md
│   └── templates/
│       └── python/
│           └── material/
│               └── labels.html
├── mkdocs.yml
├── pyproject.toml
├── sample_data/
│   └── bruce_the_gnome/
│       └── data.jsonl
├── src/
│   └── invoke_training/
│       ├── __init__.py
│       ├── _shared/
│       │   ├── __init__.py
│       │   ├── accelerator/
│       │   │   ├── __init__.py
│       │   │   └── accelerator_utils.py
│       │   ├── checkpoints/
│       │   │   ├── __init__.py
│       │   │   ├── checkpoint_tracker.py
│       │   │   ├── lora_checkpoint_utils.py
│       │   │   └── serialization.py
│       │   ├── data/
│       │   │   ├── ARCHITECTURE.md
│       │   │   ├── __init__.py
│       │   │   ├── data_loaders/
│       │   │   │   ├── __init__.py
│       │   │   │   ├── dreambooth_sd_dataloader.py
│       │   │   │   ├── image_caption_flux_dataloader.py
│       │   │   │   ├── image_caption_sd_dataloader.py
│       │   │   │   ├── image_pair_preference_sd_dataloader.py
│       │   │   │   └── textual_inversion_sd_dataloader.py
│       │   │   ├── datasets/
│       │   │   │   ├── __init__.py
│       │   │   │   ├── build_dataset.py
│       │   │   │   ├── hf_image_caption_dataset.py
│       │   │   │   ├── hf_image_pair_preference_dataset.py
│       │   │   │   ├── image_caption_dir_dataset.py
│       │   │   │   ├── image_caption_jsonl_dataset.py
│       │   │   │   ├── image_dir_dataset.py
│       │   │   │   ├── image_pair_preference_dataset.py
│       │   │   │   └── transform_dataset.py
│       │   │   ├── samplers/
│       │   │   │   ├── __init__.py
│       │   │   │   ├── aspect_ratio_bucket_batch_sampler.py
│       │   │   │   ├── batch_offset_sampler.py
│       │   │   │   ├── concat_sampler.py
│       │   │   │   ├── interleaved_sampler.py
│       │   │   │   └── offset_sampler.py
│       │   │   ├── transforms/
│       │   │   │   ├── __init__.py
│       │   │   │   ├── caption_prefix_transform.py
│       │   │   │   ├── concat_fields_transform.py
│       │   │   │   ├── constant_field_transform.py
│       │   │   │   ├── drop_field_transform.py
│       │   │   │   ├── flux_image_transform.py
│       │   │   │   ├── load_cache_transform.py
│       │   │   │   ├── sd_image_transform.py
│       │   │   │   ├── shuffle_caption_transform.py
│       │   │   │   ├── template_caption_transform.py
│       │   │   │   └── tensor_disk_cache.py
│       │   │   └── utils/
│       │   │       ├── __init__.py
│       │   │       ├── aspect_ratio_bucket_manager.py
│       │   │       ├── resize.py
│       │   │       └── resolution.py
│       │   ├── flux/
│       │   │   ├── encoding_utils.py
│       │   │   ├── lora_checkpoint_utils.py
│       │   │   ├── model_loading_utils.py
│       │   │   └── validation.py
│       │   ├── optimizer/
│       │   │   ├── __init__.py
│       │   │   └── optimizer_utils.py
│       │   ├── stable_diffusion/
│       │   │   ├── __init__.py
│       │   │   ├── base_model_version.py
│       │   │   ├── checkpoint_utils.py
│       │   │   ├── lora_checkpoint_utils.py
│       │   │   ├── min_snr_weighting.py
│       │   │   ├── model_loading_utils.py
│       │   │   ├── textual_inversion.py
│       │   │   ├── tokenize_captions.py
│       │   │   └── validation.py
│       │   ├── tools/
│       │   │   ├── __init__.py
│       │   │   └── generate_images.py
│       │   └── utils/
│       │       ├── import_xformers.py
│       │       └── jsonl.py
│       ├── config/
│       │   ├── __init__.py
│       │   ├── base_pipeline_config.py
│       │   ├── config_base_model.py
│       │   ├── data/
│       │   │   ├── __init__.py
│       │   │   ├── data_loader_config.py
│       │   │   └── dataset_config.py
│       │   ├── optimizer/
│       │   │   ├── __init__.py
│       │   │   └── optimizer_config.py
│       │   └── pipeline_config.py
│       ├── model_merge/
│       │   ├── __init__.py
│       │   ├── extract_lora.py
│       │   ├── merge_models.py
│       │   ├── merge_tasks_to_base.py
│       │   ├── scripts/
│       │   │   ├── extract_lora_from_model_diff.py
│       │   │   ├── merge_lora_into_model.py
│       │   │   ├── merge_models.py
│       │   │   └── merge_task_models_to_base_model.py
│       │   └── utils/
│       │       ├── __init__.py
│       │       ├── normalize_weights.py
│       │       └── parse_model_arg.py
│       ├── pipelines/
│       │   ├── __init__.py
│       │   ├── _experimental/
│       │   │   └── sd_dpo_lora/
│       │   │       ├── config.py
│       │   │       └── train.py
│       │   ├── callbacks.py
│       │   ├── flux/
│       │   │   └── lora/
│       │   │       ├── __init__.py
│       │   │       ├── config.py
│       │   │       └── train.py
│       │   ├── invoke_train.py
│       │   ├── stable_diffusion/
│       │   │   ├── __init__.py
│       │   │   ├── lora/
│       │   │   │   ├── __init__.py
│       │   │   │   ├── config.py
│       │   │   │   └── train.py
│       │   │   └── textual_inversion/
│       │   │       ├── __init__.py
│       │   │       ├── config.py
│       │   │       └── train.py
│       │   └── stable_diffusion_xl/
│       │       ├── __init__.py
│       │       ├── finetune/
│       │       │   ├── __init__.py
│       │       │   ├── config.py
│       │       │   └── train.py
│       │       ├── lora/
│       │       │   ├── __init__.py
│       │       │   ├── config.py
│       │       │   └── train.py
│       │       ├── lora_and_textual_inversion/
│       │       │   ├── __init__.py
│       │       │   ├── config.py
│       │       │   └── train.py
│       │       └── textual_inversion/
│       │           ├── __init__.py
│       │           ├── config.py
│       │           └── train.py
│       ├── sample_configs/
│       │   ├── _experimental/
│       │   │   ├── sd_dpo_lora_pickapic_1x24gb.yaml
│       │   │   └── sd_dpo_lora_refinement_pokemon_1x24gb.yaml
│       │   ├── flux_lora_1x40gb.yaml
│       │   ├── sd_lora_baroque_1x8gb.yaml
│       │   ├── sd_textual_inversion_gnome_1x8gb.yaml
│       │   ├── sdxl_finetune_baroque_1x24gb.yaml
│       │   ├── sdxl_finetune_robocats_1x24gb.yaml
│       │   ├── sdxl_lora_and_ti_gnome_1x24gb.yaml
│       │   ├── sdxl_lora_baroque_1x24gb.yaml
│       │   ├── sdxl_lora_baroque_1x8gb.yaml
│       │   ├── sdxl_lora_masks_gnome_1x24gb.yaml
│       │   ├── sdxl_textual_inversion_gnome_1x24gb.yaml
│       │   └── sdxl_textual_inversion_masks_gnome_1x24gb.yaml
│       ├── scripts/
│       │   ├── __init__.py
│       │   ├── _experimental/
│       │   │   ├── auto_caption/
│       │   │   │   └── auto_caption_images.py
│       │   │   ├── masks/
│       │   │   │   ├── clipseg.py
│       │   │   │   ├── generate_masks.py
│       │   │   │   └── generate_masks_for_jsonl_dataset.py
│       │   │   └── rank_images.py
│       │   ├── convert_sd_lora_to_kohya_format.py
│       │   ├── invoke_generate_images.py
│       │   ├── invoke_train.py
│       │   ├── invoke_train_ui.py
│       │   ├── invoke_visualize_data_loading.py
│       │   └── utils/
│       │       └── image_dir_dataset.py
│       └── ui/
│           ├── __init__.py
│           ├── app.py
│           ├── config_groups/
│           │   ├── __init__.py
│           │   ├── aspect_ratio_bucket_config_group.py
│           │   ├── base_pipeline_config_group.py
│           │   ├── dataset_config_group.py
│           │   ├── flux_lora_config_group.py
│           │   ├── image_caption_sd_data_loader_config_group.py
│           │   ├── optimizer_config_group.py
│           │   ├── sd_lora_config_group.py
│           │   ├── sd_textual_inversion_config_group.py
│           │   ├── sdxl_finetune_config_group.py
│           │   ├── sdxl_lora_and_textual_inversion_config_group.py
│           │   ├── sdxl_lora_config_group.py
│           │   ├── sdxl_textual_inversion_config_group.py
│           │   ├── textual_inversion_sd_data_loader_config_group.py
│           │   └── ui_config_element.py
│           ├── gradio_blocks/
│           │   ├── header.py
│           │   └── pipeline_tab.py
│           ├── index.html
│           ├── pages/
│           │   ├── data_page.py
│           │   └── training_page.py
│           └── utils/
│               ├── prompts.py
│               └── utils.py
└── tests/
    └── invoke_training/
        ├── _shared/
        │   ├── __init__.py
        │   ├── checkpoints/
        │   │   ├── test_checkpoint_tracker.py
        │   │   └── test_serialization.py
        │   ├── data/
        │   │   ├── __init__.py
        │   │   ├── data_loaders/
        │   │   │   ├── __init__.py
        │   │   │   ├── test_dreambooth_sd_dataloader.py
        │   │   │   ├── test_image_caption_sd_dataloader.py
        │   │   │   ├── test_image_pair_preference_sd_dataloader.py
        │   │   │   └── test_textual_inversion_sd_dataloader.py
        │   │   ├── dataset_fixtures.py
        │   │   ├── datasets/
        │   │   │   ├── __init__.py
        │   │   │   ├── test_hf_image_caption_dataset.py
        │   │   │   ├── test_hf_image_pair_preference_dataset.py
        │   │   │   ├── test_image_caption_dir_dataset.py
        │   │   │   ├── test_image_caption_jsonl_dataset.py
        │   │   │   ├── test_image_dir_dataset.py
        │   │   │   ├── test_image_pair_preference_dataset.py
        │   │   │   └── test_transform_dataset.py
        │   │   ├── samplers/
        │   │   │   ├── __init__.py
        │   │   │   ├── test_aspect_ratio_bucket_batch_sampler.py
        │   │   │   ├── test_batch_offset_sampler.py
        │   │   │   ├── test_concat_sampler.py
        │   │   │   ├── test_interleaved_sampler.py
        │   │   │   └── test_offset_sampler.py
        │   │   ├── transforms/
        │   │   │   ├── __init__.py
        │   │   │   ├── test_caption_prefix_transform.py
        │   │   │   ├── test_concat_fields_transform.py
        │   │   │   ├── test_constant_field_transform.py
        │   │   │   ├── test_drop_field_transform.py
        │   │   │   ├── test_load_cache_transform.py
        │   │   │   ├── test_sd_image_transform.py
        │   │   │   ├── test_shuffle_caption_transform.py
        │   │   │   ├── test_template_caption_transform.py
        │   │   │   └── test_tensor_disk_cache.py
        │   │   └── utils/
        │   │       ├── __init__.py
        │   │       ├── test_aspect_ratio_bucket_manager.py
        │   │       ├── test_resize.py
        │   │       └── test_resolution.py
        │   ├── stable_diffusion/
        │   │   ├── __init__.py
        │   │   ├── test_base_model_version.py
        │   │   ├── test_lora_checkpoint_utils.py
        │   │   ├── test_model_loading_utils.py
        │   │   ├── test_textual_inversion.py
        │   │   └── ti_embedding_checkpoint_fixture.py
        │   └── utils/
        │       └── test_jsonl.py
        ├── config/
        │   └── pipelines/
        │       └── test_pipeline_config.py
        ├── model_merge/
        │   ├── __init__.py
        │   ├── test_merge_models.py
        │   ├── test_merge_tasks_to_base.py
        │   └── utils.py
        └── ui/
            └── utils/
                └── test_prompts.py
Download .txt
SYMBOL INDEX (581 symbols across 157 files)

FILE: src/invoke_training/_shared/accelerator/accelerator_utils.py
  function initialize_accelerator (line 14) | def initialize_accelerator(
  function initialize_logging (line 40) | def initialize_logging(logger_name: str, accelerator: Accelerator) -> Mu...
  function get_mixed_precision_dtype (line 71) | def get_mixed_precision_dtype(accelerator: Accelerator):
  function get_dtype_from_str (line 95) | def get_dtype_from_str(dtype_str: Literal["float16", "bfloat16", "float3...

FILE: src/invoke_training/_shared/checkpoints/checkpoint_tracker.py
  class CheckpointTracker (line 6) | class CheckpointTracker:
    method __init__ (line 14) | def __init__(
    method prune (line 48) | def prune(self, buffer_num: int = 1) -> int:
    method get_path (line 83) | def get_path(self, epoch: int, step: int) -> str:

FILE: src/invoke_training/_shared/checkpoints/lora_checkpoint_utils.py
  function save_multi_model_peft_checkpoint (line 7) | def save_multi_model_peft_checkpoint(checkpoint_dir: Path | str, models:...
  function load_multi_model_peft_checkpoint (line 30) | def load_multi_model_peft_checkpoint(
  function _convert_peft_state_dict_to_kohya_state_dict (line 57) | def _convert_peft_state_dict_to_kohya_state_dict(
  function _convert_peft_models_to_kohya_state_dict (line 79) | def _convert_peft_models_to_kohya_state_dict(

FILE: src/invoke_training/_shared/checkpoints/serialization.py
  function save_state_dict (line 8) | def save_state_dict(state_dict: typing.Dict[str, torch.Tensor], out_file...
  function load_state_dict (line 33) | def load_state_dict(in_file: typing.Union[Path, str]) -> typing.Dict[str...

FILE: src/invoke_training/_shared/data/data_loaders/dreambooth_sd_dataloader.py
  function build_dreambooth_sd_dataloader (line 25) | def build_dreambooth_sd_dataloader(

FILE: src/invoke_training/_shared/data/data_loaders/image_caption_flux_dataloader.py
  function build_image_caption_flux_dataloader (line 33) | def build_image_caption_flux_dataloader(  # noqa: C901

FILE: src/invoke_training/_shared/data/data_loaders/image_caption_sd_dataloader.py
  function sd_image_caption_collate_fn (line 27) | def sd_image_caption_collate_fn(examples):
  function build_aspect_ratio_bucket_manager (line 64) | def build_aspect_ratio_bucket_manager(config: AspectRatioBucketConfig):
  function build_image_caption_sd_dataloader (line 73) | def build_image_caption_sd_dataloader(  # noqa: C901

FILE: src/invoke_training/_shared/data/data_loaders/image_pair_preference_sd_dataloader.py
  function sd_image_pair_preference_collate_fn (line 15) | def sd_image_pair_preference_collate_fn(examples):
  function build_image_pair_preference_sd_dataloader (line 49) | def build_image_pair_preference_sd_dataloader(

FILE: src/invoke_training/_shared/data/data_loaders/textual_inversion_sd_dataloader.py
  function get_preset_ti_caption_templates (line 33) | def get_preset_ti_caption_templates(preset: Literal["object", "style"]) ...
  function build_textual_inversion_sd_dataloader (line 97) | def build_textual_inversion_sd_dataloader(  # noqa: C901

FILE: src/invoke_training/_shared/data/datasets/build_dataset.py
  function build_hf_hub_image_caption_dataset (line 15) | def build_hf_hub_image_caption_dataset(config: HFHubImageCaptionDatasetC...
  function build_image_caption_jsonl_dataset (line 27) | def build_image_caption_jsonl_dataset(config: ImageCaptionJsonlDatasetCo...
  function build_image_caption_dir_dataset (line 36) | def build_image_caption_dir_dataset(config: ImageCaptionDirDatasetConfig...
  function build_hf_image_pair_preference_dataset (line 43) | def build_hf_image_pair_preference_dataset(

FILE: src/invoke_training/_shared/data/datasets/hf_image_caption_dataset.py
  class HFImageCaptionDataset (line 11) | class HFImageCaptionDataset(torch.utils.data.Dataset):
    method __init__ (line 18) | def __init__(self, hf_dataset, image_column: str = "image", caption_co...
    method from_dir (line 42) | def from_dir(
    method from_hub (line 67) | def from_hub(
    method get_image_dimensions (line 87) | def get_image_dimensions(self) -> list[Resolution]:
    method __len__ (line 101) | def __len__(self) -> int:
    method __getitem__ (line 109) | def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:

FILE: src/invoke_training/_shared/data/datasets/hf_image_pair_preference_dataset.py
  class HFImagePairPreferenceDataset (line 9) | class HFImagePairPreferenceDataset(torch.utils.data.Dataset):
    method __init__ (line 16) | def __init__(
    method from_hub (line 77) | def from_hub(
    method __len__ (line 101) | def __len__(self) -> int:
    method __getitem__ (line 109) | def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:

FILE: src/invoke_training/_shared/data/datasets/image_caption_dir_dataset.py
  class ImageCaptionDirDataset (line 10) | class ImageCaptionDirDataset(torch.utils.data.Dataset):
    method __init__ (line 13) | def __init__(
    method _load_image (line 64) | def _load_image(self, image_path: str) -> Image.Image:
    method get_image_dimensions (line 69) | def get_image_dimensions(self) -> list[Resolution]:
    method __len__ (line 83) | def __len__(self) -> int:
    method __getitem__ (line 86) | def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:

FILE: src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py
  class ImageCaptionExample (line 16) | class ImageCaptionExample(BaseModel):
  class ImageCaptionJsonlDataset (line 22) | class ImageCaptionJsonlDataset(torch.utils.data.Dataset):
    method __init__ (line 25) | def __init__(
    method save_jsonl (line 55) | def save_jsonl(self):
    method _get_image_path (line 67) | def _get_image_path(self, idx: int) -> str:
    method _get_mask_path (line 77) | def _get_mask_path(self, idx: int) -> str:
    method _load_image (line 87) | def _load_image(self, image_path: str) -> Image.Image:
    method _load_mask (line 92) | def _load_mask(self, mask_path: str) -> Image.Image:
    method _load_example (line 95) | def _load_example(self, idx: int) -> dict[str, typing.Any]:
    method get_image_dimensions (line 105) | def get_image_dimensions(self) -> list[Resolution]:
    method __len__ (line 118) | def __len__(self) -> int:
    method __getitem__ (line 121) | def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:

FILE: src/invoke_training/_shared/data/datasets/image_dir_dataset.py
  class ImageDirDataset (line 10) | class ImageDirDataset(torch.utils.data.Dataset):
    method __init__ (line 13) | def __init__(
    method _load_image (line 49) | def _load_image(self, image_path: str) -> Image.Image:
    method get_image_dimensions (line 54) | def get_image_dimensions(self) -> list[Resolution]:
    method __len__ (line 68) | def __len__(self) -> int:
    method __getitem__ (line 71) | def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:

FILE: src/invoke_training/_shared/data/datasets/image_pair_preference_dataset.py
  class ImagePairPreferenceDataset (line 11) | class ImagePairPreferenceDataset(torch.utils.data.Dataset):
    method __init__ (line 12) | def __init__(self, dataset_dir: str):
    method save_metadata (line 19) | def save_metadata(
    method __len__ (line 27) | def __len__(self) -> int:
    method __getitem__ (line 30) | def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:

FILE: src/invoke_training/_shared/data/datasets/transform_dataset.py
  class TransformDataset (line 11) | class TransformDataset(torch.utils.data.Dataset):
    method __init__ (line 14) | def __init__(self, base_dataset: torch.utils.data.Dataset, transforms:...
    method __len__ (line 19) | def __len__(self) -> int:
    method __getitem__ (line 22) | def __getitem__(self, idx: int) -> DataType:

FILE: src/invoke_training/_shared/data/samplers/aspect_ratio_bucket_batch_sampler.py
  class AspectRatioBucketBatchSampler (line 15) | class AspectRatioBucketBatchSampler(Sampler[list[int]]):
    method __init__ (line 18) | def __init__(
    method __str__ (line 34) | def __str__(self) -> str:
    method from_image_sizes (line 44) | def from_image_sizes(
    method _build_bucket_to_index_map (line 57) | def _build_bucket_to_index_map(
    method get_buckets (line 73) | def get_buckets(self) -> AspectRatioBuckets:
    method __iter__ (line 76) | def __iter__(self) -> Iterator[list[int]]:
    method __len__ (line 103) | def __len__(self) -> int:
  function log_aspect_ratio_buckets (line 110) | def log_aspect_ratio_buckets(logger: logging.Logger, batch_sampler: Aspe...

FILE: src/invoke_training/_shared/data/samplers/batch_offset_sampler.py
  class BatchOffsetSampler (line 6) | class BatchOffsetSampler(Sampler[int]):
    method __init__ (line 9) | def __init__(self, sampler: Sampler[int], offset: int):
    method __iter__ (line 13) | def __iter__(self) -> typing.Iterator[int]:
    method __len__ (line 18) | def __len__(self) -> int:

FILE: src/invoke_training/_shared/data/samplers/concat_sampler.py
  class ConcatSampler (line 9) | class ConcatSampler(Sampler[T_co]):
    method __init__ (line 19) | def __init__(self, samplers: list[Sampler[T_co] | typing.Iterable[T_co...
    method __iter__ (line 22) | def __iter__(self) -> typing.Iterator[T_co]:
    method __len__ (line 25) | def __len__(self) -> int:

FILE: src/invoke_training/_shared/data/samplers/interleaved_sampler.py
  class InterleavedSampler (line 8) | class InterleavedSampler(Sampler[T_co]):
    method __init__ (line 21) | def __init__(self, samplers: list[Sampler[T_co] | typing.Iterable[T_co...
    method __iter__ (line 25) | def __iter__(self) -> typing.Iterator[T_co]:
    method __len__ (line 38) | def __len__(self) -> int:

FILE: src/invoke_training/_shared/data/samplers/offset_sampler.py
  class OffsetSampler (line 6) | class OffsetSampler(Sampler[int]):
    method __init__ (line 9) | def __init__(self, sampler: Sampler[int], offset: int):
    method __iter__ (line 13) | def __iter__(self) -> typing.Iterator[int]:
    method __len__ (line 17) | def __len__(self) -> int:

FILE: src/invoke_training/_shared/data/transforms/caption_prefix_transform.py
  class CaptionPrefixTransform (line 4) | class CaptionPrefixTransform:
    method __init__ (line 7) | def __init__(self, caption_field_name: str, prefix: str):
    method __call__ (line 11) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...

FILE: src/invoke_training/_shared/data/transforms/concat_fields_transform.py
  class ConcatFieldsTransform (line 4) | class ConcatFieldsTransform:
    method __init__ (line 7) | def __init__(self, src_field_names: list[str], dst_field_name: str, se...
    method __call__ (line 12) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...

FILE: src/invoke_training/_shared/data/transforms/constant_field_transform.py
  class ConstantFieldTransform (line 4) | class ConstantFieldTransform:
    method __init__ (line 7) | def __init__(self, field_name: str, field_value: typing.Any):
    method __call__ (line 11) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...

FILE: src/invoke_training/_shared/data/transforms/drop_field_transform.py
  class DropFieldTransform (line 4) | class DropFieldTransform:
    method __init__ (line 7) | def __init__(self, field_to_drop: str):
    method __call__ (line 10) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...

FILE: src/invoke_training/_shared/data/transforms/flux_image_transform.py
  class FluxImageTransform (line 10) | class FluxImageTransform:
    method __init__ (line 13) | def __init__(
    method __call__ (line 42) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...

FILE: src/invoke_training/_shared/data/transforms/load_cache_transform.py
  class LoadCacheTransform (line 6) | class LoadCacheTransform:
    method __init__ (line 9) | def __init__(
    method __call__ (line 24) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...

FILE: src/invoke_training/_shared/data/transforms/sd_image_transform.py
  class SDImageTransform (line 11) | class SDImageTransform:
    method __init__ (line 14) | def __init__(
    method __call__ (line 59) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...

FILE: src/invoke_training/_shared/data/transforms/shuffle_caption_transform.py
  class ShuffleCaptionTransform (line 6) | class ShuffleCaptionTransform:
    method __init__ (line 14) | def __init__(self, field_name: str, delimiter: str = ",", seed: int = 0):
    method __call__ (line 19) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...

FILE: src/invoke_training/_shared/data/transforms/template_caption_transform.py
  class TemplateCaptionTransform (line 6) | class TemplateCaptionTransform:
    method __init__ (line 11) | def __init__(self, field_name: str, placeholder_str: str, caption_temp...
    method __call__ (line 17) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...

FILE: src/invoke_training/_shared/data/transforms/tensor_disk_cache.py
  class TensorDiskCache (line 7) | class TensorDiskCache:
    method __init__ (line 10) | def __init__(self, cache_dir: str):
    method _get_path (line 16) | def _get_path(self, key: int):
    method save (line 25) | def save(self, key: int, data: typing.Dict[str, torch.Tensor]):
    method load (line 41) | def load(self, key: int) -> typing.Dict[str, torch.Tensor]:

FILE: src/invoke_training/_shared/data/utils/aspect_ratio_bucket_manager.py
  class AspectRatioBucketManager (line 4) | class AspectRatioBucketManager:
    method __init__ (line 5) | def __init__(self, buckets: set[Resolution]):
    method from_constraints (line 9) | def from_constraints(cls, target_resolution: int, start_dim: int, end_...
    method build_aspect_ratio_buckets (line 19) | def build_aspect_ratio_buckets(
    method get_aspect_ratio_bucket (line 56) | def get_aspect_ratio_bucket(self, resolution: Resolution):

FILE: src/invoke_training/_shared/data/utils/resize.py
  function resize_to_cover (line 9) | def resize_to_cover(image: Image, size_to_cover: Resolution) -> Image:

FILE: src/invoke_training/_shared/data/utils/resolution.py
  class Resolution (line 4) | class Resolution:
    method __init__ (line 5) | def __init__(self, height: int, width: int):
    method parse (line 10) | def parse(cls, resolution: Union[int, tuple[int, int], "Resolution"]):
    method aspect_ratio (line 23) | def aspect_ratio(self):
    method to_tuple (line 26) | def to_tuple(self) -> tuple[int, int]:
    method __eq__ (line 29) | def __eq__(self, other: "Resolution") -> bool:
    method __lt__ (line 32) | def __lt__(self, other: "Resolution") -> bool:
    method __hash__ (line 35) | def __hash__(self):

FILE: src/invoke_training/_shared/flux/encoding_utils.py
  function get_clip_prompt_embeds (line 8) | def get_clip_prompt_embeds(
  function get_t5_prompt_embeds (line 55) | def get_t5_prompt_embeds(
  function handle_lora_scale (line 101) | def handle_lora_scale(
  function reset_lora_scale (line 121) | def reset_lora_scale(
  function encode_prompt (line 141) | def encode_prompt(

FILE: src/invoke_training/_shared/flux/lora_checkpoint_utils.py
  function save_flux_peft_checkpoint (line 62) | def save_flux_peft_checkpoint(
  function load_flux_peft_checkpoint (line 79) | def load_flux_peft_checkpoint(
  function save_flux_kohya_checkpoint (line 100) | def save_flux_kohya_checkpoint(
  function convert_flux_peft_checkpoint_to_kohya_state_dict (line 121) | def convert_flux_peft_checkpoint_to_kohya_state_dict(
  function _convert_peft_models_to_kohya_state_dict (line 158) | def _convert_peft_models_to_kohya_state_dict(
  function find_matching_key_prefix (line 185) | def find_matching_key_prefix(state_dict, key_pattern):
  function convert_layer_weights (line 204) | def convert_layer_weights(target_dict, source_dict, source_pattern, targ...
  function convert_double_transformer_block (line 231) | def convert_double_transformer_block(target_dict, source_dict, prefix=""...
  function convert_single_transformer_block (line 333) | def convert_single_transformer_block(target_dict, source_dict, prefix, b...
  function convert_embedding_layers (line 384) | def convert_embedding_layers(target_dict, source_dict, prefix, has_guida...
  function convert_output_layers (line 436) | def convert_output_layers(target_dict, source_dict, prefix):
  function convert_diffusers_to_flux_transformer_checkpoint (line 460) | def convert_diffusers_to_flux_transformer_checkpoint(

FILE: src/invoke_training/_shared/flux/model_loading_utils.py
  class PipelineVersionEnum (line 9) | class PipelineVersionEnum(Enum):
  function load_pipeline (line 13) | def load_pipeline(
  function load_models_flux (line 65) | def load_models_flux(

FILE: src/invoke_training/_shared/flux/validation.py
  function generate_validation_images_flux (line 25) | def generate_validation_images_flux(  # noqa: C901

FILE: src/invoke_training/_shared/optimizer/optimizer_utils.py
  function initialize_optimizer (line 7) | def initialize_optimizer(

FILE: src/invoke_training/_shared/stable_diffusion/base_model_version.py
  class BaseModelVersionEnum (line 6) | class BaseModelVersionEnum(Enum):
  function get_base_model_version (line 13) | def get_base_model_version(
  function check_base_model_version (line 54) | def check_base_model_version(

FILE: src/invoke_training/_shared/stable_diffusion/checkpoint_utils.py
  function save_sdxl_diffusers_unet_checkpoint (line 8) | def save_sdxl_diffusers_unet_checkpoint(
  function save_sdxl_diffusers_checkpoint (line 25) | def save_sdxl_diffusers_checkpoint(

FILE: src/invoke_training/_shared/stable_diffusion/lora_checkpoint_utils.py
  function save_sd_peft_checkpoint (line 66) | def save_sd_peft_checkpoint(
  function load_sd_peft_checkpoint (line 78) | def load_sd_peft_checkpoint(
  function save_sdxl_peft_checkpoint (line 91) | def save_sdxl_peft_checkpoint(
  function load_sdxl_peft_checkpoint (line 108) | def load_sdxl_peft_checkpoint(
  function save_sd_kohya_checkpoint (line 129) | def save_sd_kohya_checkpoint(checkpoint_path: Path, unet: peft.PeftModel...
  function save_sdxl_kohya_checkpoint (line 143) | def save_sdxl_kohya_checkpoint(
  function convert_sd_peft_checkpoint_to_kohya_state_dict (line 165) | def convert_sd_peft_checkpoint_to_kohya_state_dict(

FILE: src/invoke_training/_shared/stable_diffusion/min_snr_weighting.py
  function compute_snr (line 5) | def compute_snr(noise_scheduler: DDPMScheduler, timesteps: torch.Tensor):

FILE: src/invoke_training/_shared/stable_diffusion/model_loading_utils.py
  class PipelineVersionEnum (line 21) | class PipelineVersionEnum(Enum):
  function load_pipeline (line 26) | def load_pipeline(
  function from_pretrained_with_variant_fallback (line 75) | def from_pretrained_with_variant_fallback(
  function load_models_sd (line 113) | def load_models_sd(
  function load_models_sdxl (line 167) | def load_models_sdxl(

FILE: src/invoke_training/_shared/stable_diffusion/textual_inversion.py
  function _expand_placeholder_token (line 10) | def _expand_placeholder_token(placeholder_token: str, num_vectors: int =...
  function _add_tokens_to_tokenizer (line 23) | def _add_tokens_to_tokenizer(placeholder_tokens: list[str], tokenizer: P...
  function expand_placeholders_in_caption (line 37) | def expand_placeholders_in_caption(caption: str, tokenizer: CLIPTokenize...
  function initialize_placeholder_tokens_from_initializer_token (line 68) | def initialize_placeholder_tokens_from_initializer_token(
  function initialize_placeholder_tokens_from_initial_phrase (line 106) | def initialize_placeholder_tokens_from_initial_phrase(
  function initialize_placeholder_tokens_from_initial_embedding (line 131) | def initialize_placeholder_tokens_from_initial_embedding(
  function restore_original_embeddings (line 172) | def restore_original_embeddings(

FILE: src/invoke_training/_shared/stable_diffusion/tokenize_captions.py
  function tokenize_captions (line 7) | def tokenize_captions(tokenizer: CLIPTokenizer, captions: list[str]) -> ...

FILE: src/invoke_training/_shared/stable_diffusion/validation.py
  function generate_validation_images_sd (line 24) | def generate_validation_images_sd(  # noqa: C901
  function generate_validation_images_sdxl (line 140) | def generate_validation_images_sdxl(  # noqa: C901

FILE: src/invoke_training/_shared/tools/generate_images.py
  function generate_images (line 15) | def generate_images(

FILE: src/invoke_training/_shared/utils/import_xformers.py
  function import_xformers (line 1) | def import_xformers():

FILE: src/invoke_training/_shared/utils/jsonl.py
  function load_jsonl (line 6) | def load_jsonl(jsonl_path: Path | str) -> list[Any]:
  function save_jsonl (line 15) | def save_jsonl(data: list[Any], jsonl_path: Path | str) -> None:

FILE: src/invoke_training/config/base_pipeline_config.py
  class BasePipelineConfig (line 7) | class BasePipelineConfig(ConfigBaseModel):

FILE: src/invoke_training/config/config_base_model.py
  class ConfigBaseModel (line 4) | class ConfigBaseModel(BaseModel):

FILE: src/invoke_training/config/data/data_loader_config.py
  class AspectRatioBucketConfig (line 10) | class AspectRatioBucketConfig(ConfigBaseModel):
  class ImageCaptionSDDataLoaderConfig (line 42) | class ImageCaptionSDDataLoaderConfig(ConfigBaseModel):
  class ImageCaptionFluxDataLoaderConfig (line 73) | class ImageCaptionFluxDataLoaderConfig(ConfigBaseModel):
  class DreamboothSDDataLoaderConfig (line 104) | class DreamboothSDDataLoaderConfig(ConfigBaseModel):
  class TextualInversionSDDataLoaderConfig (line 142) | class TextualInversionSDDataLoaderConfig(ConfigBaseModel):

FILE: src/invoke_training/config/data/dataset_config.py
  class HFHubImageCaptionDatasetConfig (line 8) | class HFHubImageCaptionDatasetConfig(ConfigBaseModel):
  class ImageCaptionJsonlDatasetConfig (line 33) | class ImageCaptionJsonlDatasetConfig(ConfigBaseModel):
  class ImageDirDatasetConfig (line 54) | class ImageDirDatasetConfig(ConfigBaseModel):
  class ImageCaptionDirDatasetConfig (line 67) | class ImageCaptionDirDatasetConfig(ConfigBaseModel):

FILE: src/invoke_training/config/optimizer/optimizer_config.py
  class AdamOptimizerConfig (line 6) | class AdamOptimizerConfig(ConfigBaseModel):
  class ProdigyOptimizerConfig (line 26) | class ProdigyOptimizerConfig(ConfigBaseModel):

FILE: src/invoke_training/model_merge/extract_lora.py
  function get_patched_base_weights_from_peft_model (line 10) | def get_patched_base_weights_from_peft_model(peft_model: PeftModel) -> d...
  function get_state_dict_diff (line 29) | def get_state_dict_diff(
  function extract_lora_from_diffs (line 37) | def extract_lora_from_diffs(

FILE: src/invoke_training/model_merge/merge_models.py
  function merge_models (line 10) | def merge_models(
  function lerp (line 53) | def lerp(a: torch.Tensor, b: torch.Tensor, weight_a: float) -> torch.Ten...
  function slerp (line 58) | def slerp(a: torch.Tensor, b: torch.Tensor, weight_a: float, dot_product...

FILE: src/invoke_training/model_merge/merge_tasks_to_base.py
  function merge_tasks_to_base_model (line 9) | def merge_tasks_to_base_model(

FILE: src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py
  class StableDiffusionModel (line 38) | class StableDiffusionModel:
    method all_none (line 45) | def all_none(self) -> bool:
  function load_model (line 49) | def load_model(
  function str_to_device (line 111) | def str_to_device(device_str: Literal["cuda", "cpu"]) -> torch.device:
  function state_dict_to_device (line 120) | def state_dict_to_device(state_dict: dict[str, torch.Tensor], device: to...
  function extract_lora_from_submodel (line 124) | def extract_lora_from_submodel(
  function extract_lora (line 187) | def extract_lora(
  function main (line 257) | def main():

FILE: src/invoke_training/model_merge/scripts/merge_lora_into_model.py
  function to_invokeai_base_model_type (line 24) | def to_invokeai_base_model_type(model_type: PipelineVersionEnum):
  function merge_lora_into_sd_model (line 34) | def merge_lora_into_sd_model(
  function parse_lora_model_arg (line 102) | def parse_lora_model_arg(lora_model_arg: str) -> tuple[str, float]:
  function main (line 113) | def main():

FILE: src/invoke_training/model_merge/scripts/merge_models.py
  class MergeModel (line 16) | class MergeModel:
  function run_merge_models (line 22) | def run_merge_models(
  function parse_model_args (line 73) | def parse_model_args(models: list[str], weights: list[str]) -> list[Merg...
  function main (line 85) | def main():

FILE: src/invoke_training/model_merge/scripts/merge_task_models_to_base_model.py
  function run_merge_models (line 14) | def run_merge_models(
  function main (line 91) | def main():

FILE: src/invoke_training/model_merge/utils/normalize_weights.py
  function normalize_weights (line 1) | def normalize_weights(weights: list[float]) -> list[float]:

FILE: src/invoke_training/model_merge/utils/parse_model_arg.py
  function parse_model_arg (line 1) | def parse_model_arg(model: str, delimiter: str = "::") -> tuple[str, str...

FILE: src/invoke_training/pipelines/_experimental/sd_dpo_lora/config.py
  class HFHubImagePairPreferenceDatasetConfig (line 10) | class HFHubImagePairPreferenceDatasetConfig(ConfigBaseModel):
  class ImagePairPreferenceDatasetConfig (line 16) | class ImagePairPreferenceDatasetConfig(ConfigBaseModel):
  class ImagePairPreferenceSDDataLoaderConfig (line 23) | class ImagePairPreferenceSDDataLoaderConfig(ConfigBaseModel):
  class SdDirectPreferenceOptimizationLoraConfig (line 50) | class SdDirectPreferenceOptimizationLoraConfig(BasePipelineConfig):
    method check_validation_prompts (line 242) | def check_validation_prompts(self):

FILE: src/invoke_training/pipelines/_experimental/sd_dpo_lora/train.py
  function _save_sd_lora_checkpoint (line 47) | def _save_sd_lora_checkpoint(
  function train_forward_dpo (line 70) | def train_forward_dpo(  # noqa: C901
  function train (line 191) | def train(config: SdDirectPreferenceOptimizationLoraConfig, callbacks: l...

FILE: src/invoke_training/pipelines/callbacks.py
  class ModelType (line 5) | class ModelType(Enum):
  class ModelCheckpoint (line 37) | class ModelCheckpoint:
    method __init__ (line 40) | def __init__(self, file_path: str, model_type: ModelType):
  class TrainingCheckpoint (line 45) | class TrainingCheckpoint:
    method __init__ (line 50) | def __init__(self, models: list[ModelCheckpoint], epoch: int, step: int):
  class ValidationImage (line 56) | class ValidationImage:
    method __init__ (line 57) | def __init__(self, file_path: str, prompt: str, image_idx: int):
  class ValidationImages (line 71) | class ValidationImages:
    method __init__ (line 72) | def __init__(self, images: list[ValidationImage], epoch: int, step: int):
  class PipelineCallbacks (line 85) | class PipelineCallbacks(ABC):
    method on_save_checkpoint (line 86) | def on_save_checkpoint(self, checkpoint: TrainingCheckpoint):
    method on_save_validation_images (line 89) | def on_save_validation_images(self, images: ValidationImages):

FILE: src/invoke_training/pipelines/flux/lora/config.py
  class FluxLoraConfig (line 17) | class FluxLoraConfig(BasePipelineConfig):

FILE: src/invoke_training/pipelines/flux/lora/train.py
  function _save_flux_lora_checkpoint (line 47) | def _save_flux_lora_checkpoint(
  function _build_data_loader (line 86) | def _build_data_loader(
  function cache_text_encoder_outputs (line 109) | def cache_text_encoder_outputs(
  function cache_vae_outputs (line 137) | def cache_vae_outputs(cache_dir: str, data_loader: DataLoader, vae: Auto...
  function get_sigmas (line 156) | def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch....
  function get_noisy_latents (line 168) | def get_noisy_latents(noise_scheduler: FlowMatchEulerDiscreteScheduler, ...
  function decode_latents (line 209) | def decode_latents(vae: AutoencoderKL, latents: torch.Tensor):
  function train_forward (line 222) | def train_forward(  # noqa: C901
  function train (line 304) | def train(config: FluxLoraConfig, callbacks: list[PipelineCallbacks] | N...

FILE: src/invoke_training/pipelines/invoke_train.py
  function train (line 17) | def train(config: PipelineConfig, callbacks: list[PipelineCallbacks] | N...

FILE: src/invoke_training/pipelines/stable_diffusion/lora/config.py
  class SdLoraConfig (line 14) | class SdLoraConfig(BasePipelineConfig):
    method check_validation_prompts (line 224) | def check_validation_prompts(self):

FILE: src/invoke_training/pipelines/stable_diffusion/lora/train.py
  function _save_sd_lora_checkpoint (line 46) | def _save_sd_lora_checkpoint(
  function _build_data_loader (line 80) | def _build_data_loader(
  function cache_text_encoder_outputs (line 115) | def cache_text_encoder_outputs(
  function cache_vae_outputs (line 143) | def cache_vae_outputs(cache_dir: str, data_loader: DataLoader, vae: Auto...
  function train_forward (line 162) | def train_forward(  # noqa: C901
  function train (line 267) | def train(config: SdLoraConfig, callbacks: list[PipelineCallbacks] | Non...

FILE: src/invoke_training/pipelines/stable_diffusion/textual_inversion/config.py
  class SdTextualInversionConfig (line 10) | class SdTextualInversionConfig(BasePipelineConfig):
    method check_validation_prompts (line 198) | def check_validation_prompts(self):

FILE: src/invoke_training/pipelines/stable_diffusion/textual_inversion/train.py
  function _save_ti_embeddings (line 41) | def _save_ti_embeddings(
  function _initialize_placeholder_tokens (line 80) | def _initialize_placeholder_tokens(
  function train (line 138) | def train(config: SdTextualInversionConfig, callbacks: list[PipelineCall...

FILE: src/invoke_training/pipelines/stable_diffusion_xl/finetune/config.py
  class SdxlFinetuneConfig (line 10) | class SdxlFinetuneConfig(BasePipelineConfig):
    method check_validation_prompts (line 166) | def check_validation_prompts(self):

FILE: src/invoke_training/pipelines/stable_diffusion_xl/finetune/train.py
  function _save_sdxl_checkpoint (line 44) | def _save_sdxl_checkpoint(
  function train (line 96) | def train(config: SdxlFinetuneConfig, callbacks: list[PipelineCallbacks]...

FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora/config.py
  class SdxlLoraConfig (line 14) | class SdxlLoraConfig(BasePipelineConfig):
    method check_validation_prompts (line 230) | def check_validation_prompts(self):

FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora/train.py
  function _save_sdxl_lora_checkpoint (line 49) | def _save_sdxl_lora_checkpoint(
  function _build_data_loader (line 88) | def _build_data_loader(
  function _encode_prompt (line 131) | def _encode_prompt(text_encoders: list[CLIPPreTrainedModel], prompt_toke...
  function cache_text_encoder_outputs (line 158) | def cache_text_encoder_outputs(
  function train_forward (line 200) | def train_forward(  # noqa: C901
  function train (line 335) | def train(config: SdxlLoraConfig, callbacks: list[PipelineCallbacks] | N...

FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/config.py
  class SdxlLoraAndTextualInversionConfig (line 10) | class SdxlLoraAndTextualInversionConfig(BasePipelineConfig):
    method check_validation_prompts (line 233) | def check_validation_prompts(self):

FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/train.py
  function _save_sdxl_lora_and_ti_checkpoint (line 50) | def _save_sdxl_lora_and_ti_checkpoint(
  function train (line 118) | def train(config: SdxlLoraAndTextualInversionConfig, callbacks: list[Pip...

FILE: src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/config.py
  class SdxlTextualInversionConfig (line 10) | class SdxlTextualInversionConfig(BasePipelineConfig):
    method check_validation_prompts (line 200) | def check_validation_prompts(self):

FILE: src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/train.py
  function _save_ti_embeddings (line 44) | def _save_ti_embeddings(
  function _initialize_placeholder_tokens (line 93) | def _initialize_placeholder_tokens(
  function train (line 164) | def train(config: SdxlTextualInversionConfig, callbacks: list[PipelineCa...

FILE: src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py
  function select_device_and_dtype (line 14) | def select_device_and_dtype(force_cpu: bool = False) -> tuple[torch.devi...
  function process_images (line 24) | def process_images(images: list[Image.Image], prompt: str, moondream, to...
  function main (line 35) | def main(

FILE: src/invoke_training/scripts/_experimental/masks/clipseg.py
  function load_clipseg_model (line 6) | def load_clipseg_model() -> tuple[CLIPSegProcessor, CLIPSegForImageSegme...
  function run_clipseg (line 13) | def run_clipseg(
  function select_device (line 58) | def select_device() -> torch.device:

FILE: src/invoke_training/scripts/_experimental/masks/generate_masks.py
  function generate_masks (line 13) | def generate_masks(image_dir: str, prompt: str, clipseg_temp: float, bat...
  function main (line 53) | def main():

FILE: src/invoke_training/scripts/_experimental/masks/generate_masks_for_jsonl_dataset.py
  function collate_fn (line 16) | def collate_fn(examples):
  function validate_out_json_path (line 24) | def validate_out_json_path(out_json_path: str | Path):
  function generate_masks (line 33) | def generate_masks(
  function main (line 92) | def main():

FILE: src/invoke_training/scripts/_experimental/rank_images.py
  function parse_args (line 15) | def parse_args():
  function clip (line 28) | def clip(val, min_val, max_val):
  function main (line 32) | def main():

FILE: src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py
  function parse_args (line 11) | def parse_args():
  function main (line 36) | def main():

FILE: src/invoke_training/scripts/invoke_generate_images.py
  function parse_args (line 8) | def parse_args():
  function parse_lora_args (line 95) | def parse_lora_args(lora_args: list[str] | None) -> list[tuple[Path, int]]:
  function parse_prompt_file (line 113) | def parse_prompt_file(prompt_file: str) -> list[str]:
  function main (line 120) | def main():

FILE: src/invoke_training/scripts/invoke_train.py
  function parse_args (line 11) | def parse_args():
  function main (line 23) | def main():

FILE: src/invoke_training/scripts/invoke_train_ui.py
  function main (line 8) | def main():

FILE: src/invoke_training/scripts/invoke_visualize_data_loading.py
  function save_image (line 26) | def save_image(torch_image: torch.Tensor, out_path: Path):
  function parse_args (line 49) | def parse_args():
  function visualize (line 62) | def visualize(data_loader: DataLoader):
  function main (line 89) | def main():

FILE: src/invoke_training/scripts/utils/image_dir_dataset.py
  class ImageDirDataset (line 8) | class ImageDirDataset(torch.utils.data.Dataset):
    method __init__ (line 11) | def __init__(
    method _load_image (line 29) | def _load_image(self, image_path: str) -> Image.Image:
    method __len__ (line 34) | def __len__(self) -> int:
    method __getitem__ (line 37) | def __getitem__(self, idx: int):
  function list_collate_fn (line 43) | def list_collate_fn(examples):

FILE: src/invoke_training/ui/app.py
  function build_app (line 12) | def build_app():

FILE: src/invoke_training/ui/config_groups/aspect_ratio_bucket_config_group.py
  class AspectRatioBucketConfigGroup (line 9) | class AspectRatioBucketConfigGroup(UIConfigElement):
    method __init__ (line 10) | def __init__(self):
    method update_ui_components_with_config_data (line 23) | def update_ui_components_with_config_data(
    method update_config_with_ui_component_data (line 41) | def update_config_with_ui_component_data(

FILE: src/invoke_training/ui/config_groups/base_pipeline_config_group.py
  class BasePipelineConfigGroup (line 10) | class BasePipelineConfigGroup(UIConfigElement):
    method __init__ (line 11) | def __init__(self):
    method update_ui_components_with_config_data (line 55) | def update_ui_components_with_config_data(self, config: BasePipelineCo...
    method update_config_with_ui_component_data (line 94) | def update_config_with_ui_component_data(

FILE: src/invoke_training/ui/config_groups/dataset_config_group.py
  class HFHubImageCaptionDatasetConfigGroup (line 22) | class HFHubImageCaptionDatasetConfigGroup(UIConfigElement):
    method __init__ (line 23) | def __init__(self):
    method update_ui_components_with_config_data (line 43) | def update_ui_components_with_config_data(
    method update_config_with_ui_component_data (line 54) | def update_config_with_ui_component_data(
  class ImageCaptionJsonlDatasetConfigGroup (line 70) | class ImageCaptionJsonlDatasetConfigGroup(UIConfigElement):
    method __init__ (line 71) | def __init__(self):
    method update_ui_components_with_config_data (line 90) | def update_ui_components_with_config_data(
    method update_config_with_ui_component_data (line 104) | def update_config_with_ui_component_data(
  class ImageCaptionDirDatasetConfigGroup (line 119) | class ImageCaptionDirDatasetConfigGroup(UIConfigElement):
    method __init__ (line 120) | def __init__(self):
    method update_ui_components_with_config_data (line 133) | def update_ui_components_with_config_data(
    method update_config_with_ui_component_data (line 141) | def update_config_with_ui_component_data(
  class ImageDirDatasetConfigGroup (line 153) | class ImageDirDatasetConfigGroup(UIConfigElement):
    method __init__ (line 154) | def __init__(self):
    method update_ui_components_with_config_data (line 167) | def update_ui_components_with_config_data(
    method update_config_with_ui_component_data (line 175) | def update_config_with_ui_component_data(
  class DatasetConfigGroup (line 187) | class DatasetConfigGroup(UIConfigElement):
    method __init__ (line 188) | def __init__(self, allowed_types: list[str]):
    method _on_type_change (line 224) | def _on_type_change(self, type: str):
    method update_ui_components_with_config_data (line 232) | def update_ui_components_with_config_data(
    method update_config_with_ui_component_data (line 270) | def update_config_with_ui_component_data(

FILE: src/invoke_training/ui/config_groups/flux_lora_config_group.py
  class FluxLoraConfigGroup (line 15) | class FluxLoraConfigGroup(UIConfigElement):
    method __init__ (line 16) | def __init__(self):
    method get_ui_output_components (line 185) | def get_ui_output_components(self) -> list[gr.components.Component]:
    method update_ui_components_with_config_data (line 224) | def update_ui_components_with_config_data(
    method update_config_with_ui_component_data (line 285) | def update_config_with_ui_component_data(  # noqa: C901

FILE: src/invoke_training/ui/config_groups/image_caption_sd_data_loader_config_group.py
  class ImageCaptionSDDataLoaderConfigGroup (line 11) | class ImageCaptionSDDataLoaderConfigGroup(UIConfigElement):
    method __init__ (line 12) | def __init__(self):
    method update_ui_components_with_config_data (line 63) | def update_ui_components_with_config_data(
    method update_config_with_ui_component_data (line 81) | def update_config_with_ui_component_data(

FILE: src/invoke_training/ui/config_groups/optimizer_config_group.py
  class AdamOptimizerConfigGroup (line 11) | class AdamOptimizerConfigGroup(UIConfigElement):
    method __init__ (line 12) | def __init__(self):
    method update_ui_components_with_config_data (line 34) | def update_ui_components_with_config_data(self, config: AdamOptimizerC...
    method update_config_with_ui_component_data (line 44) | def update_config_with_ui_component_data(
  class ProdigyOptimizerConfigGroup (line 59) | class ProdigyOptimizerConfigGroup(UIConfigElement):
    method __init__ (line 60) | def __init__(self):
    method update_ui_components_with_config_data (line 77) | def update_ui_components_with_config_data(
    method update_config_with_ui_component_data (line 87) | def update_config_with_ui_component_data(
  class OptimizerConfigGroup (line 100) | class OptimizerConfigGroup(UIConfigElement):
    method __init__ (line 101) | def __init__(self):
    method _on_optimizer_type_change (line 119) | def _on_optimizer_type_change(self, optimizer_type: str):
    method update_ui_components_with_config_data (line 125) | def update_ui_components_with_config_data(self, config: OptimizerConfi...
    method update_config_with_ui_component_data (line 145) | def update_config_with_ui_component_data(self, orig_config: OptimizerC...

FILE: src/invoke_training/ui/config_groups/sd_lora_config_group.py
  class SdLoraConfigGroup (line 19) | class SdLoraConfigGroup(UIConfigElement):
    method __init__ (line 20) | def __init__(self):
    method update_ui_components_with_config_data (line 180) | def update_ui_components_with_config_data(self, config: SdLoraConfig) ...
    method update_config_with_ui_component_data (line 218) | def update_config_with_ui_component_data(

FILE: src/invoke_training/ui/config_groups/sd_textual_inversion_config_group.py
  class SdTextualInversionConfigGroup (line 19) | class SdTextualInversionConfigGroup(UIConfigElement):
    method __init__ (line 20) | def __init__(self):
    method update_ui_components_with_config_data (line 180) | def update_ui_components_with_config_data(
    method update_config_with_ui_component_data (line 218) | def update_config_with_ui_component_data(

FILE: src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py
  class SdxlFinetuneConfigGroup (line 19) | class SdxlFinetuneConfigGroup(UIConfigElement):
    method __init__ (line 20) | def __init__(self):
    method update_ui_components_with_config_data (line 177) | def update_ui_components_with_config_data(
    method update_config_with_ui_component_data (line 215) | def update_config_with_ui_component_data(

FILE: src/invoke_training/ui/config_groups/sdxl_lora_and_textual_inversion_config_group.py
  class SdxlLoraAndTextualInversionConfigGroup (line 21) | class SdxlLoraAndTextualInversionConfigGroup(UIConfigElement):
    method __init__ (line 22) | def __init__(self):
    method update_ui_components_with_config_data (line 225) | def update_ui_components_with_config_data(
    method update_config_with_ui_component_data (line 273) | def update_config_with_ui_component_data(

FILE: src/invoke_training/ui/config_groups/sdxl_lora_config_group.py
  class SdxlLoraConfigGroup (line 19) | class SdxlLoraConfigGroup(UIConfigElement):
    method __init__ (line 20) | def __init__(self):
    method update_ui_components_with_config_data (line 186) | def update_ui_components_with_config_data(
    method update_config_with_ui_component_data (line 227) | def update_config_with_ui_component_data(

FILE: src/invoke_training/ui/config_groups/sdxl_textual_inversion_config_group.py
  class SdxlTextualInversionConfigGroup (line 19) | class SdxlTextualInversionConfigGroup(UIConfigElement):
    method __init__ (line 20) | def __init__(self):
    method update_ui_components_with_config_data (line 186) | def update_ui_components_with_config_data(
    method update_config_with_ui_component_data (line 225) | def update_config_with_ui_component_data(

FILE: src/invoke_training/ui/config_groups/textual_inversion_sd_data_loader_config_group.py
  class TextualInversionSDDataLoaderConfigGroup (line 13) | class TextualInversionSDDataLoaderConfigGroup(UIConfigElement):
    method __init__ (line 14) | def __init__(self):
    method update_ui_components_with_config_data (line 88) | def update_ui_components_with_config_data(
    method update_config_with_ui_component_data (line 114) | def update_config_with_ui_component_data(

FILE: src/invoke_training/ui/config_groups/ui_config_element.py
  class UIConfigElement (line 6) | class UIConfigElement:
    method get_ui_output_components (line 9) | def get_ui_output_components(self) -> list[gr.components.Component]:
    method get_ui_input_components (line 19) | def get_ui_input_components(self) -> list[gr.components.Component]:
    method update_ui_components_with_config_data (line 29) | def update_ui_components_with_config_data(self, config) -> dict[gr.com...
    method update_config_with_ui_component_data (line 33) | def update_config_with_ui_component_data(self, orig_config, ui_data: d...

FILE: src/invoke_training/ui/gradio_blocks/header.py
  class Header (line 6) | class Header:
    method __init__ (line 7) | def __init__(self):

FILE: src/invoke_training/ui/gradio_blocks/pipeline_tab.py
  class PipelineTab (line 11) | class PipelineTab:
    method __init__ (line 12) | def __init__(
    method on_reset_config_button_click (line 102) | def on_reset_config_button_click(self, file_path: str):
    method on_generate_config_button_click (line 129) | def on_generate_config_button_click(self, data: dict):
    method on_run_training_button_click (line 162) | def on_run_training_button_click(self):

FILE: src/invoke_training/ui/pages/data_page.py
  class DataPage (line 18) | class DataPage:
    method __init__ (line 19) | def __init__(self):
    method _update_state (line 190) | def _update_state(self, idx: int):
    method _on_load_existing_dataset_button_click (line 234) | def _on_load_existing_dataset_button_click(self, data: dict):
    method _on_create_dataset_button_click (line 249) | def _on_create_dataset_button_click(self, data: dict):
    method _on_change_dataset_button_click (line 269) | def _on_change_dataset_button_click(self):
    method _on_save_and_go_button_click (line 274) | def _on_save_and_go_button_click(self, data: dict, idx_change: int):
    method _on_save_and_next_button_click (line 288) | def _on_save_and_next_button_click(self, data: dict):
    method _on_save_and_prev_button_click (line 291) | def _on_save_and_prev_button_click(self, data: dict):
    method _on_cur_example_index_change (line 294) | def _on_cur_example_index_change(self, data: dict):
    method _on_add_images_button_click (line 297) | def _on_add_images_button_click(self, data: dict):
    method app (line 336) | def app(self):

FILE: src/invoke_training/ui/pages/training_page.py
  class TrainingPage (line 33) | class TrainingPage:
    method __init__ (line 34) | def __init__(self):
    method app (line 152) | def app(self):
    method _run_training (line 155) | def _run_training(self, config: PipelineConfig):

FILE: src/invoke_training/ui/utils/prompts.py
  function split_pos_neg_prompts (line 4) | def split_pos_neg_prompts(prompt: str) -> tuple[str, str]:
  function merge_pos_neg_prompts (line 28) | def merge_pos_neg_prompts(positive_prompt: str, negative_prompt: str) ->...
  function convert_ui_prompts_to_pos_neg_prompts (line 47) | def convert_ui_prompts_to_pos_neg_prompts(prompts: str) -> tuple[list[st...
  function convert_pos_neg_prompts_to_ui_prompts (line 69) | def convert_pos_neg_prompts_to_ui_prompts(positive_prompts: list[str], n...

FILE: src/invoke_training/ui/utils/utils.py
  function get_config_dir_path (line 10) | def get_config_dir_path() -> Path:
  function get_assets_dir_path (line 17) | def get_assets_dir_path() -> Path:
  function load_config_from_yaml (line 24) | def load_config_from_yaml(file_path: Path | str) -> PipelineConfig:
  function get_typing_literal_options (line 35) | def get_typing_literal_options(cls, field_name: str) -> list[str]:

FILE: tests/invoke_training/_shared/checkpoints/test_checkpoint_tracker.py
  function test_checkpoint_tracker_get_path_file (line 10) | def test_checkpoint_tracker_get_path_file():
  function test_checkpoint_tracker_get_path_directory (line 24) | def test_checkpoint_tracker_get_path_directory():
  function test_checkpoint_tracker_bad_extension (line 38) | def test_checkpoint_tracker_bad_extension():
  function test_checkpoint_tracker_prune_files (line 46) | def test_checkpoint_tracker_prune_files():
  function test_checkpoint_tracker_prune_directories (line 65) | def test_checkpoint_tracker_prune_directories():
  function test_checkpoint_tracker_prune_no_max (line 86) | def test_checkpoint_tracker_prune_no_max():

FILE: tests/invoke_training/_shared/checkpoints/test_serialization.py
  function test_state_dict_save_and_load_roundtrip (line 14) | def test_state_dict_save_and_load_roundtrip(file_name):
  function test_save_state_dict_bad_extension (line 29) | def test_save_state_dict_bad_extension():
  function test_load_state_dict_bad_extension (line 35) | def test_load_state_dict_bad_extension():

FILE: tests/invoke_training/_shared/data/data_loaders/test_dreambooth_sd_dataloader.py
  function test_build_dreambooth_sd_dataloader (line 12) | def test_build_dreambooth_sd_dataloader(image_dir):  # noqa: F811
  function test_build_dreambooth_sd_dataloader_no_class_dataset (line 48) | def test_build_dreambooth_sd_dataloader_no_class_dataset(image_dir):  # ...
  function test_build_dreambooth_sd_dataloader_with_bucketing (line 82) | def test_build_dreambooth_sd_dataloader_with_bucketing(image_dir):  # no...

FILE: tests/invoke_training/_shared/data/data_loaders/test_image_caption_sd_dataloader.py
  function test_build_image_caption_sd_dataloader (line 12) | def test_build_image_caption_sd_dataloader(image_caption_jsonl):  # noqa...
  function test_build_image_caption_sd_dataloader_with_masks (line 41) | def test_build_image_caption_sd_dataloader_with_masks(image_caption_json...

FILE: tests/invoke_training/_shared/data/data_loaders/test_image_pair_preference_sd_dataloader.py
  function test_build_image_pair_preference_sd_dataloader (line 16) | def test_build_image_pair_preference_sd_dataloader():

FILE: tests/invoke_training/_shared/data/data_loaders/test_textual_inversion_sd_dataloader.py
  function test_build_textual_inversion_sd_dataloader (line 15) | def test_build_textual_inversion_sd_dataloader(image_dir):  # noqa: F811
  function test_build_textual_inversion_sd_dataloader_keep_original_captions (line 50) | def test_build_textual_inversion_sd_dataloader_keep_original_captions(im...
  function test_build_textual_inversion_sd_dataloader_with_masks (line 72) | def test_build_textual_inversion_sd_dataloader_with_masks(image_caption_...

FILE: tests/invoke_training/_shared/data/dataset_fixtures.py
  function image_dir (line 10) | def image_dir(tmp_path_factory: pytest.TempPathFactory):
  function image_caption_dir (line 30) | def image_caption_dir(tmp_path_factory: pytest.TempPathFactory):
  function image_caption_jsonl (line 53) | def image_caption_jsonl(tmp_path_factory: pytest.TempPathFactory):
  function image_pair_preference_dir (line 88) | def image_pair_preference_dir(tmp_path_factory: pytest.TempPathFactory):

FILE: tests/invoke_training/_shared/data/datasets/test_hf_image_caption_dataset.py
  function create_hf_imagefolder_dataset (line 19) | def create_hf_imagefolder_dataset(tmp_dir: Path, num_images: int):
  function hf_imagefolder_dir (line 46) | def hf_imagefolder_dir(tmp_path_factory: pytest.TempPathFactory):
  function hf_dir_dataset (line 62) | def hf_dir_dataset(hf_imagefolder_dir: Path):
  function test_hf_dir_image_caption_dataset_bad_image_column (line 66) | def test_hf_dir_image_caption_dataset_bad_image_column(hf_imagefolder_di...
  function test_hf_dir_image_caption_dataset_bad_caption_column (line 74) | def test_hf_dir_image_caption_dataset_bad_caption_column(hf_imagefolder_...
  function test_hf_dir_image_caption_dataset_len (line 82) | def test_hf_dir_image_caption_dataset_len(hf_dir_dataset: HFImageCaption...
  function test_hf_dir_image_caption_dataset_index_error (line 87) | def test_hf_dir_image_caption_dataset_index_error(hf_dir_dataset: HFImag...
  function test_hf_dir_image_caption_dataset_getitem (line 93) | def test_hf_dir_image_caption_dataset_getitem(hf_dir_dataset: HFImageCap...
  function test_hf_dir_image_caption_dataset_get_image_dimensions (line 104) | def test_hf_dir_image_caption_dataset_get_image_dimensions(hf_dir_datase...
  function test_hf_hub_image_caption_dataset_bad_image_column (line 121) | def test_hf_hub_image_caption_dataset_bad_image_column():
  function test_hf_hub_image_caption_dataset_bad_caption_column (line 135) | def test_hf_hub_image_caption_dataset_bad_caption_column():
  function hf_hub_dataset (line 148) | def hf_hub_dataset():
  function test_hf_hub_image_caption_dataset_index_error (line 157) | def test_hf_hub_image_caption_dataset_index_error(hf_hub_dataset: HFImag...
  function test_hf_hub_image_caption_dataset_len (line 165) | def test_hf_hub_image_caption_dataset_len(hf_hub_dataset: HFImageCaption...
  function test_hf_hub_image_caption_dataset_getitem (line 174) | def test_hf_hub_image_caption_dataset_getitem(hf_hub_dataset: HFImageCap...
  function test_hf_hub_image_caption_dataset_get_image_dimensions (line 187) | def test_hf_hub_image_caption_dataset_get_image_dimensions(hf_hub_datase...

FILE: tests/invoke_training/_shared/data/datasets/test_hf_image_pair_preference_dataset.py
  function test_hf_hub_image_caption_dataset_getitem (line 9) | def test_hf_hub_image_caption_dataset_getitem():
  function test_hf_hub_image_caption_dataset_len (line 45) | def test_hf_hub_image_caption_dataset_len():
  function test_hf_hub_image_caption_dataset_skip_no_preference_len (line 66) | def test_hf_hub_image_caption_dataset_skip_no_preference_len():

FILE: tests/invoke_training/_shared/data/datasets/test_image_caption_dir_dataset.py
  function test_image_caption_dir_dataset_len (line 11) | def test_image_caption_dir_dataset_len(image_caption_dir):  # noqa: F811
  function test_image_caption_dir_dataset_getitem (line 17) | def test_image_caption_dir_dataset_getitem(image_caption_dir):  # noqa: ...
  function test_image_caption_dir_dataset_keep_in_memory (line 29) | def test_image_caption_dir_dataset_keep_in_memory(image_caption_dir):  #...
  function test_image_caption_dir_dataset_get_image_dimensions (line 41) | def test_image_caption_dir_dataset_get_image_dimensions(image_caption_di...
  function test_image_caption_dir_dataset_missing_caption_file (line 49) | def test_image_caption_dir_dataset_missing_caption_file(tmp_path: Path):...

FILE: tests/invoke_training/_shared/data/datasets/test_image_caption_jsonl_dataset.py
  function test_image_caption_jsonl_dataset_len (line 12) | def test_image_caption_jsonl_dataset_len(image_caption_jsonl):  # noqa: ...
  function test_image_caption_jsonl_dataset_getitem (line 18) | def test_image_caption_jsonl_dataset_getitem(image_caption_jsonl):  # no...
  function test_image_caption_jsonl_dataset_keep_in_memory (line 32) | def test_image_caption_jsonl_dataset_keep_in_memory(image_caption_jsonl)...
  function test_image_caption_jsonl_dataset_get_image_dimensions (line 53) | def test_image_caption_jsonl_dataset_get_image_dimensions(image_caption_...
  function test_image_caption_jsonl_dataset_save_jsonl (line 61) | def test_image_caption_jsonl_dataset_save_jsonl(image_caption_jsonl, tmp...

FILE: tests/invoke_training/_shared/data/datasets/test_image_dir_dataset.py
  function test_image_dir_dataset_len (line 8) | def test_image_dir_dataset_len(image_dir):  # noqa: F811
  function test_image_dir_dataset_getitem (line 14) | def test_image_dir_dataset_getitem(image_dir):  # noqa: F811
  function test_image_dir_dataset_keep_in_memory (line 25) | def test_image_dir_dataset_keep_in_memory(image_dir):  # noqa: F811
  function test_image_dir_dataset_get_image_dimensions (line 43) | def test_image_dir_dataset_get_image_dimensions(image_dir):  # noqa: F811

FILE: tests/invoke_training/_shared/data/datasets/test_image_pair_preference_dataset.py
  function test_image_dir_dataset_len (line 8) | def test_image_dir_dataset_len(image_pair_preference_dir):  # noqa: F811
  function test_image_dir_dataset_getitem (line 14) | def test_image_dir_dataset_getitem(image_pair_preference_dir):  # noqa: ...

FILE: tests/invoke_training/_shared/data/datasets/test_transform_dataset.py
  function test_transform_dataset_len (line 6) | def test_transform_dataset_len():
  function test_transform_dataset_getitem (line 16) | def test_transform_dataset_getitem():

FILE: tests/invoke_training/_shared/data/samplers/test_aspect_ratio_bucket_batch_sampler.py
  function assert_shuffled_samples_match (line 8) | def assert_shuffled_samples_match(samples_1, samples_2):
  function test_aspect_ratio_bucket_batch_sampler (line 18) | def test_aspect_ratio_bucket_batch_sampler():
  function test_aspect_ratio_bucket_batch_sampler_len (line 30) | def test_aspect_ratio_bucket_batch_sampler_len():
  function test_aspect_ratio_bucket_batch_sampler_from_image_sizes (line 42) | def test_aspect_ratio_bucket_batch_sampler_from_image_sizes():
  function test_aspect_ratio_bucket_batch_sampler_shuffle (line 66) | def test_aspect_ratio_bucket_batch_sampler_shuffle():
  function test_aspect_ratio_bucket_batch_sampler_seed (line 81) | def test_aspect_ratio_bucket_batch_sampler_seed():

FILE: tests/invoke_training/_shared/data/samplers/test_batch_offset_sampler.py
  function test_batch_offset_sampler (line 6) | def test_batch_offset_sampler():
  function test_batch_offset_sampler_len (line 18) | def test_batch_offset_sampler_len():

FILE: tests/invoke_training/_shared/data/samplers/test_concat_sampler.py
  function test_concat_sampler (line 4) | def test_concat_sampler():
  function test_concat_sampler_batches (line 16) | def test_concat_sampler_batches():
  function test_concat_sampler_len (line 28) | def test_concat_sampler_len():

FILE: tests/invoke_training/_shared/data/samplers/test_interleaved_sampler.py
  function test_interleaved_sampler (line 4) | def test_interleaved_sampler():
  function test_interleaved_sampler_batches (line 16) | def test_interleaved_sampler_batches():
  function test_interleaved_sampler_len (line 28) | def test_interleaved_sampler_len():

FILE: tests/invoke_training/_shared/data/samplers/test_offset_sampler.py
  function test_offset_sampler (line 6) | def test_offset_sampler():
  function test_offset_sampler_len (line 16) | def test_offset_sampler_len():

FILE: tests/invoke_training/_shared/data/transforms/test_caption_prefix_transform.py
  function test_caption_prefix_transform (line 4) | def test_caption_prefix_transform():

FILE: tests/invoke_training/_shared/data/transforms/test_concat_fields_transform.py
  function test_caption_prefix_transform (line 4) | def test_caption_prefix_transform():

FILE: tests/invoke_training/_shared/data/transforms/test_constant_field_transform.py
  function test_constant_field_transform (line 4) | def test_constant_field_transform():

FILE: tests/invoke_training/_shared/data/transforms/test_drop_field_transform.py
  function test_drop_field_transform (line 4) | def test_drop_field_transform():

FILE: tests/invoke_training/_shared/data/transforms/test_load_cache_transform.py
  function test_load_cache_transform (line 8) | def test_load_cache_transform():

FILE: tests/invoke_training/_shared/data/transforms/test_sd_image_transform.py
  function denormalize_image (line 13) | def denormalize_image(img: np.ndarray) -> np.ndarray:
  function denormalize_mask (line 32) | def denormalize_mask(mask: np.ndarray) -> np.ndarray:
  function test_sd_image_transform_resolution (line 41) | def test_sd_image_transform_resolution():
  function test_sd_image_transform_without_mask (line 69) | def test_sd_image_transform_without_mask():
  function test_sd_image_transform_range (line 92) | def test_sd_image_transform_range():
  function test_sd_image_transform_center_crop (line 124) | def test_sd_image_transform_center_crop():
  function test_sd_image_transform_random_crop (line 155) | def test_sd_image_transform_random_crop():
  function test_sd_image_transform_center_crop_flip (line 193) | def test_sd_image_transform_center_crop_flip():
  function test_sd_image_transform_random_crop_flip (line 228) | def test_sd_image_transform_random_crop_flip():
  function test_sd_image_transform_aspect_ratio_bucket_manager (line 270) | def test_sd_image_transform_aspect_ratio_bucket_manager():
  function test_sd_image_transform_resolution_input_validation (line 309) | def test_sd_image_transform_resolution_input_validation(

FILE: tests/invoke_training/_shared/data/transforms/test_shuffle_caption_transform.py
  function test_shuffle_caption_transform (line 4) | def test_shuffle_caption_transform():
  function test_shuffle_caption_transform_no_delimiter (line 15) | def test_shuffle_caption_transform_no_delimiter():

FILE: tests/invoke_training/_shared/data/transforms/test_template_caption_transform.py
  function test_template_caption_transform (line 8) | def test_template_caption_transform():
  function test_template_caption_transform_seed (line 20) | def test_template_caption_transform_seed():
  function test_template_caption_transform_bad_templates (line 55) | def test_template_caption_transform_bad_templates():

FILE: tests/invoke_training/_shared/data/transforms/test_tensor_disk_cache.py
  function test_tensor_disk_cache_roundtrip (line 9) | def test_tensor_disk_cache_roundtrip(tmp_path: Path):
  function test_tensor_disk_cache_fail_overwrite (line 26) | def test_tensor_disk_cache_fail_overwrite(tmp_path):

FILE: tests/invoke_training/_shared/data/utils/test_aspect_ratio_bucket_manager.py
  function test_build_aspect_ratio_buckets_input_validation (line 19) | def test_build_aspect_ratio_buckets_input_validation(
  function test_build_aspect_ratio_buckets (line 54) | def test_build_aspect_ratio_buckets(
  function test_get_aspect_ratio_bucket (line 79) | def test_get_aspect_ratio_bucket(resolution: Resolution, expected_bucket...

FILE: tests/invoke_training/_shared/data/utils/test_resize.py
  function test_resize_to_cover (line 28) | def test_resize_to_cover(in_resolution: Resolution, size_to_cover: Resol...

FILE: tests/invoke_training/_shared/data/utils/test_resolution.py
  function test_resolution_parse (line 14) | def test_resolution_parse(input, expected_resolution: Resolution):

FILE: tests/invoke_training/_shared/stable_diffusion/test_base_model_version.py
  function test_get_base_model_version (line 21) | def test_get_base_model_version(diffusers_model_name: str, expected_vers...
  function test_check_base_model_version_pass (line 39) | def test_check_base_model_version_pass():
  function test_check_base_model_version_fail (line 45) | def test_check_base_model_version_fail():

FILE: tests/invoke_training/_shared/stable_diffusion/test_lora_checkpoint_utils.py
  function test_convert_sd_peft_checkpoint_to_kohya_state_dict_raise_on_empty_directory (line 10) | def test_convert_sd_peft_checkpoint_to_kohya_state_dict_raise_on_empty_d...
  function test_convert_sd_peft_checkpoint_to_kohya_state_dict_raise_on_unexpected_subdirectory (line 17) | def test_convert_sd_peft_checkpoint_to_kohya_state_dict_raise_on_unexpec...

FILE: tests/invoke_training/_shared/stable_diffusion/test_model_loading_utils.py
  function test_load_models_sd (line 17) | def test_load_models_sd(sdv1_embedding_path):  # noqa: F811
  function test_load_models_sdxl (line 38) | def test_load_models_sdxl(sdxl_embedding_path: Path):  # noqa: F811

FILE: tests/invoke_training/_shared/stable_diffusion/test_textual_inversion.py
  function test_expand_placeholder_token (line 22) | def test_expand_placeholder_token(placeholder_token: str, num_vectors: i...
  function test_expand_placeholder_token_raises_on_invalid_num_vectors (line 26) | def test_expand_placeholder_token_raises_on_invalid_num_vectors():
  function test_initialize_placeholder_tokens_from_initializer_token (line 32) | def test_initialize_placeholder_tokens_from_initializer_token():
  function test_initialize_placeholder_tokens_from_initial_phrase (line 60) | def test_initialize_placeholder_tokens_from_initial_phrase():
  function test_initialize_placeholder_tokens_from_initial_embedding (line 87) | def test_initialize_placeholder_tokens_from_initial_embedding(sdv1_embed...

FILE: tests/invoke_training/_shared/stable_diffusion/ti_embedding_checkpoint_fixture.py
  function sdv1_embedding_path (line 8) | def sdv1_embedding_path(tmp_path_factory: pytest.TempPathFactory):
  function sdxl_embedding_path (line 26) | def sdxl_embedding_path(tmp_path_factory: pytest.TempPathFactory):

FILE: tests/invoke_training/_shared/utils/test_jsonl.py
  function test_jsonl_roundtrip (line 6) | def test_jsonl_roundtrip(tmp_path: Path):

FILE: tests/invoke_training/config/pipelines/test_pipeline_config.py
  function test_pipeline_config (line 10) | def test_pipeline_config():

FILE: tests/invoke_training/model_merge/test_merge_models.py
  function test_merge_models_raises_on_not_enough_state_dicts (line 12) | def test_merge_models_raises_on_not_enough_state_dicts():
  function test_merge_models_raises_on_mismatched_weights (line 17) | def test_merge_models_raises_on_mismatched_weights():
  function test_merge_models (line 88) | def test_merge_models(

FILE: tests/invoke_training/model_merge/test_merge_tasks_to_base.py
  function test_merge_raises_on_mismatched_weights (line 11) | def test_merge_raises_on_mismatched_weights():
  function test_merge_ties (line 65) | def test_merge_ties(

FILE: tests/invoke_training/model_merge/utils.py
  function state_dicts_are_close (line 4) | def state_dicts_are_close(a: dict[str, torch.Tensor], b: dict[str, torch...

FILE: tests/invoke_training/ui/utils/test_prompts.py
  function test_split_pos_neg_prompts (line 21) | def test_split_pos_neg_prompts(prompt: str, expected_positive_prompt: st...
  function test_split_pos_neg_prompts_raises_value_error (line 34) | def test_split_pos_neg_prompts_raises_value_error(prompt: str):
  function test_convert_ui_prompts_to_pos_neg_prompts (line 73) | def test_convert_ui_prompts_to_pos_neg_prompts(
  function test_convert_pos_neg_prompts_to_ui_prompts (line 82) | def test_convert_pos_neg_prompts_to_ui_prompts(
Condensed preview — 246 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (1,009K chars).
[
  {
    "path": ".github/workflows/deploy.yaml",
    "chars": 954,
    "preview": "name: Deploy invoke-training docs\n\non:\n  push:\n    branches:\n      - main\n\npermissions:\n  contents: write\n\njobs:\n  deplo"
  },
  {
    "path": ".github/workflows/test.yaml",
    "chars": 1270,
    "preview": "name: Test invoke-training\n\non:\n  push:\n    branches:\n      - main\n  pull_request:\n  workflow_dispatch:\n\njobs:\n  build:\n"
  },
  {
    "path": ".gitignore",
    "chars": 3169,
    "preview": "/output/\n/test_configs/\n/data/\n\n# pyenv\n.python-version\n\n# VSCode\n.vscode/\n\n# Byte-compiled / optimized / DLL files\n__py"
  },
  {
    "path": ".pre-commit-config.yaml",
    "chars": 237,
    "preview": "# See https://pre-commit.com/ for usage and config.\nrepos:\n- repo: https://github.com/astral-sh/ruff-pre-commit\n  # Ruff"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 2179,
    "preview": "# invoke-training\n\nA library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion,"
  },
  {
    "path": "docs/contributing/development_environment.md",
    "chars": 135,
    "preview": "# Development Environment Setup\n\nSee the [developer installation instructions](../get-started/installation.md#developer-"
  },
  {
    "path": "docs/contributing/directory_structure.md",
    "chars": 706,
    "preview": "# Directory Structure\n\n```bash\ninvoke-training/\n├── README.md\n├── docs/\n├── src/\n│   └── invoke-training/\n│       ├── _s"
  },
  {
    "path": "docs/contributing/documentation.md",
    "chars": 225,
    "preview": "# Documentation\n\nThe documentation site is generated using [mkdocs](https://www.mkdocs.org/) and [mkdocstrings-python](h"
  },
  {
    "path": "docs/contributing/tests.md",
    "chars": 393,
    "preview": "# Tests\n\nRun all unit tests with:\n\n```bash\npytest tests/\n```\n\nThere are some test 'markers' defined in [pyproject.toml]("
  },
  {
    "path": "docs/get-started/installation.md",
    "chars": 2348,
    "preview": "# Installation\n\n## Requirements\n\n1. Python 3.10, 3.11 and 3.12 are currently supported. Check your Python version by run"
  },
  {
    "path": "docs/get-started/quick-start.md",
    "chars": 3352,
    "preview": "# Quick Start\n\n`invoke-training` has both a GUI and a CLI (for advanced users). The instructions for getting started wit"
  },
  {
    "path": "docs/guides/dataset_formats.md",
    "chars": 3778,
    "preview": "# Dataset Formats\n\n`invoke-training` supports the following dataset formats:\n\n- `IMAGE_CAPTION_JSONL_DATASET`: A local i"
  },
  {
    "path": "docs/guides/model_merge.md",
    "chars": 1824,
    "preview": "# Model Merging\n\n`invoke-training` provides utility scripts for several common model merging workflows. This page contai"
  },
  {
    "path": "docs/guides/stable_diffusion/dpo_lora_sd.md",
    "chars": 3900,
    "preview": "# (Experimental) Diffusion DPO - SD\n\n!!! tip \"Experimental\"\n    The Diffusion Direct Preference Optimization training pi"
  },
  {
    "path": "docs/guides/stable_diffusion/gnome_lora_masks_sdxl.md",
    "chars": 4972,
    "preview": "# LoRA with Masks - SDXL\n\nThis tutorial explains how to prepare masks for an image dataset and then use that dataset to "
  },
  {
    "path": "docs/guides/stable_diffusion/robocats_finetune_sdxl.md",
    "chars": 7415,
    "preview": "# Finetune - SDXL\n\nThis tutorial explains how to do a full finetune training run on a [Stable Diffusion XL](https://hugg"
  },
  {
    "path": "docs/guides/stable_diffusion/textual_inversion_sdxl.md",
    "chars": 3804,
    "preview": "# Textual Inversion - SDXL\n\nThis tutorial walks through a [Textual Inversion](https://arxiv.org/abs/2208.01618) training"
  },
  {
    "path": "docs/index.md",
    "chars": 693,
    "preview": "# invoke-training\n\nA library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion,"
  },
  {
    "path": "docs/reference/config/index.md",
    "chars": 373,
    "preview": "# Config Reference\n\nThis section contains reference documentation for the `invoke-training` configuration schema (i.e. d"
  },
  {
    "path": "docs/reference/config/pipelines/sd_lora.md",
    "chars": 478,
    "preview": "# `SdLoraConfig`\n\n<!-- To control the member order, we first list out the members whose order we care about, then we lis"
  },
  {
    "path": "docs/reference/config/pipelines/sd_textual_inversion.md",
    "chars": 541,
    "preview": "# `SdTextualInversionConfig`\n\n<!-- To control the member order, we first list out the members whose order we care about,"
  },
  {
    "path": "docs/reference/config/pipelines/sdxl_finetune.md",
    "chars": 511,
    "preview": "# `SdxlFinetuneConfig`\n\n<!-- To control the member order, we first list out the members whose order we care about, then "
  },
  {
    "path": "docs/reference/config/pipelines/sdxl_lora.md",
    "chars": 490,
    "preview": "# `SdxlLoraConfig`\n\n<!-- To control the member order, we first list out the members whose order we care about, then we l"
  },
  {
    "path": "docs/reference/config/pipelines/sdxl_lora_and_textual_inversion.md",
    "chars": 591,
    "preview": "# `SdxlLoraAndTextualInversionConfig`\n\n<!-- To control the member order, we first list out the members whose order we ca"
  },
  {
    "path": "docs/reference/config/pipelines/sdxl_textual_inversion.md",
    "chars": 988,
    "preview": "# `SdxlTextualInversionConfig`\n\nBelow is a sample yaml config file for Textual Inversion SDXL training ([raw file](https"
  },
  {
    "path": "docs/reference/config/shared/data/data_loader_config.md",
    "chars": 104,
    "preview": "::: invoke_training.config.data.data_loader_config\n    options:\n      filters:\n      - \"!^model_config\"\n"
  },
  {
    "path": "docs/reference/config/shared/data/dataset_config.md",
    "chars": 99,
    "preview": "::: invoke_training.config.data.dataset_config\n    options:\n      filters:\n      - \"!^model_config\""
  },
  {
    "path": "docs/reference/config/shared/optimizer_config.md",
    "chars": 107,
    "preview": "::: invoke_training.config.optimizer.optimizer_config\n    options:\n      filters:\n      - \"!^model_config\"\n"
  },
  {
    "path": "docs/templates/python/material/labels.html",
    "chars": 266,
    "preview": "<!--\n    This file is intentionally empty. It overrides the default contents of\n    https://github.com/mkdocstrings/pyth"
  },
  {
    "path": "mkdocs.yml",
    "chars": 2583,
    "preview": "site_name: invoke-training\nsite_url: https://invoke-ai.github.io/invoke-training/\n\nrepo_name: invoke-ai/invoke-training\n"
  },
  {
    "path": "pyproject.toml",
    "chars": 2104,
    "preview": "[build-system]\nrequires = [\"setuptools>=65.5\", \"pip>=22.3\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"i"
  },
  {
    "path": "sample_data/bruce_the_gnome/data.jsonl",
    "chars": 451,
    "preview": "{\"image\": \"001.png\", \"text\": \"A stuffed gnome sits on a wooden floor, facing right with a gray couch in the background.\""
  },
  {
    "path": "src/invoke_training/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/_shared/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/_shared/accelerator/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/_shared/accelerator/accelerator_utils.py",
    "chars": 3573,
    "preview": "import logging\nimport os\nfrom typing import Literal\n\nimport datasets\nimport diffusers\nimport torch\nimport transformers\nf"
  },
  {
    "path": "src/invoke_training/_shared/checkpoints/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/_shared/checkpoints/checkpoint_tracker.py",
    "chars": 3876,
    "preview": "import os\nimport shutil\nimport typing\n\n\nclass CheckpointTracker:\n    \"\"\"A utility class for managing checkpoint paths.\n\n"
  },
  {
    "path": "src/invoke_training/_shared/checkpoints/lora_checkpoint_utils.py",
    "chars": 3805,
    "preview": "from pathlib import Path\n\nimport peft\nimport torch\n\n\ndef save_multi_model_peft_checkpoint(checkpoint_dir: Path | str, mo"
  },
  {
    "path": "src/invoke_training/_shared/checkpoints/serialization.py",
    "chars": 1883,
    "preview": "import typing\nfrom pathlib import Path\n\nimport safetensors.torch\nimport torch\n\n\ndef save_state_dict(state_dict: typing.D"
  },
  {
    "path": "src/invoke_training/_shared/data/ARCHITECTURE.md",
    "chars": 1284,
    "preview": "# Dataset Architecture\nDataset handling is split into 3 layers of abstraction: Datasets, Transforms, and DataLoaders. Ea"
  },
  {
    "path": "src/invoke_training/_shared/data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/_shared/data/data_loaders/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/_shared/data/data_loaders/dreambooth_sd_dataloader.py",
    "chars": 8355,
    "preview": "import typing\n\nfrom torch.utils.data import ConcatDataset, DataLoader\nfrom torch.utils.data.sampler import RandomSampler"
  },
  {
    "path": "src/invoke_training/_shared/data/data_loaders/image_caption_flux_dataloader.py",
    "chars": 6423,
    "preview": "import typing\n\nfrom torch.utils.data import DataLoader\n\nfrom invoke_training._shared.data.data_loaders.image_caption_sd_"
  },
  {
    "path": "src/invoke_training/_shared/data/data_loaders/image_caption_sd_dataloader.py",
    "chars": 8199,
    "preview": "import typing\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom invoke_training._shared.data.datasets.build_da"
  },
  {
    "path": "src/invoke_training/_shared/data/data_loaders/image_pair_preference_sd_dataloader.py",
    "chars": 5973,
    "preview": "import typing\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom invoke_training._shared.data.datasets.build_da"
  },
  {
    "path": "src/invoke_training/_shared/data/data_loaders/textual_inversion_sd_dataloader.py",
    "chars": 10371,
    "preview": "from typing import Literal, Optional\n\nfrom torch.utils.data import DataLoader\n\nfrom invoke_training._shared.data.data_lo"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/build_dataset.py",
    "chars": 2930,
    "preview": "from datasets import VerificationMode\n\nfrom invoke_training._shared.data.datasets.hf_image_caption_dataset import HFImag"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/hf_image_caption_dataset.py",
    "chars": 5046,
    "preview": "import os\nimport typing\n\nimport datasets\nimport torch.utils.data\nfrom PIL.Image import Image\n\nfrom invoke_training._shar"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/hf_image_pair_preference_dataset.py",
    "chars": 4649,
    "preview": "import io\nimport typing\n\nimport datasets\nimport torch.utils.data\nfrom PIL import Image\n\n\nclass HFImagePairPreferenceData"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/image_caption_dir_dataset.py",
    "chars": 3753,
    "preview": "import os\nimport typing\n\nimport torch.utils.data\nfrom PIL import Image\n\nfrom invoke_training._shared.data.utils.resoluti"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py",
    "chars": 4708,
    "preview": "import typing\nfrom pathlib import Path\n\nimport torch.utils.data\nfrom PIL import Image\nfrom pydantic import BaseModel\n\nfr"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/image_dir_dataset.py",
    "chars": 2883,
    "preview": "import os\nimport typing\n\nimport torch.utils.data\nfrom PIL import Image\n\nfrom invoke_training._shared.data.utils.resoluti"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/image_pair_preference_dataset.py",
    "chars": 1549,
    "preview": "import os\nimport typing\nfrom pathlib import Path\n\nimport torch.utils.data\nfrom PIL import Image\n\nfrom invoke_training._s"
  },
  {
    "path": "src/invoke_training/_shared/data/datasets/transform_dataset.py",
    "chars": 834,
    "preview": "import typing\n\nimport torch.utils.data\n\n# The data type expected to be produced by the base dataset and handled by trans"
  },
  {
    "path": "src/invoke_training/_shared/data/samplers/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/_shared/data/samplers/aspect_ratio_bucket_batch_sampler.py",
    "chars": 4341,
    "preview": "import copy\nimport logging\nimport math\nimport random\nfrom typing import Iterator\n\nfrom torch.utils.data import Sampler\n\n"
  },
  {
    "path": "src/invoke_training/_shared/data/samplers/batch_offset_sampler.py",
    "chars": 560,
    "preview": "import typing\n\nfrom torch.utils.data import Sampler\n\n\nclass BatchOffsetSampler(Sampler[int]):\n    \"\"\"A sampler that wrap"
  },
  {
    "path": "src/invoke_training/_shared/data/samplers/concat_sampler.py",
    "chars": 685,
    "preview": "import itertools\nimport typing\n\nfrom torch.utils.data import Sampler\n\nT_co = typing.TypeVar(\"T_co\", covariant=True)\n\n\ncl"
  },
  {
    "path": "src/invoke_training/_shared/data/samplers/interleaved_sampler.py",
    "chars": 1264,
    "preview": "import typing\n\nfrom torch.utils.data import Sampler\n\nT_co = typing.TypeVar(\"T_co\", covariant=True)\n\n\nclass InterleavedSa"
  },
  {
    "path": "src/invoke_training/_shared/data/samplers/offset_sampler.py",
    "chars": 490,
    "preview": "import typing\n\nfrom torch.utils.data import Sampler\n\n\nclass OffsetSampler(Sampler[int]):\n    \"\"\"A sampler that wraps ano"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/caption_prefix_transform.py",
    "chars": 459,
    "preview": "import typing\n\n\nclass CaptionPrefixTransform:\n    \"\"\"A transform that adds a prefix to all example captions.\"\"\"\n\n    def"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/concat_fields_transform.py",
    "chars": 589,
    "preview": "import typing\n\n\nclass ConcatFieldsTransform:\n    \"\"\"A transform that concatenate multiple string fields.\"\"\"\n\n    def __i"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/constant_field_transform.py",
    "chars": 429,
    "preview": "import typing\n\n\nclass ConstantFieldTransform:\n    \"\"\"A simple transform that adds a constant field to every example.\"\"\"\n"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/drop_field_transform.py",
    "chars": 391,
    "preview": "import typing\n\n\nclass DropFieldTransform:\n    \"\"\"A simple transform that drops a field from an example.\"\"\"\n\n    def __in"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/flux_image_transform.py",
    "chars": 4527,
    "preview": "import typing\n\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\n\nfrom invoke_traini"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/load_cache_transform.py",
    "chars": 1187,
    "preview": "import typing\n\nfrom invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache\n\n\nclass LoadCacheTr"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/sd_image_transform.py",
    "chars": 6214,
    "preview": "import random\nimport typing\n\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\n\nfrom"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/shuffle_caption_transform.py",
    "chars": 955,
    "preview": "import typing\n\nimport numpy as np\n\n\nclass ShuffleCaptionTransform:\n    \"\"\"A transform that applies shuffle transformatio"
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/template_caption_transform.py",
    "chars": 908,
    "preview": "import typing\n\nimport numpy as np\n\n\nclass TemplateCaptionTransform:\n    \"\"\"A simple transform that constructs a caption "
  },
  {
    "path": "src/invoke_training/_shared/data/transforms/tensor_disk_cache.py",
    "chars": 1525,
    "preview": "import os\nimport typing\n\nimport torch\n\n\nclass TensorDiskCache:\n    \"\"\"A data cache that caches `torch.Tensor`s on disk.\""
  },
  {
    "path": "src/invoke_training/_shared/data/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/_shared/data/utils/aspect_ratio_bucket_manager.py",
    "chars": 2228,
    "preview": "from invoke_training._shared.data.utils.resolution import Resolution\n\n\nclass AspectRatioBucketManager:\n    def __init__("
  },
  {
    "path": "src/invoke_training/_shared/data/utils/resize.py",
    "chars": 1148,
    "preview": "import math\n\nfrom PIL.Image import Image\nfrom torchvision import transforms\n\nfrom invoke_training._shared.data.utils.res"
  },
  {
    "path": "src/invoke_training/_shared/data/utils/resolution.py",
    "chars": 1176,
    "preview": "from typing import Union\n\n\nclass Resolution:\n    def __init__(self, height: int, width: int):\n        self.height = heig"
  },
  {
    "path": "src/invoke_training/_shared/flux/encoding_utils.py",
    "chars": 7754,
    "preview": "import logging\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom transformers import CLIPTextModel, CLI"
  },
  {
    "path": "src/invoke_training/_shared/flux/lora_checkpoint_utils.py",
    "chars": 19688,
    "preview": "# ruff: noqa: N806\nimport os\nfrom pathlib import Path\n\nimport peft\nimport torch\nfrom diffusers import FluxTransformer2DM"
  },
  {
    "path": "src/invoke_training/_shared/flux/model_loading_utils.py",
    "chars": 6226,
    "preview": "import logging\nfrom enum import Enum\n\nimport torch\nfrom diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler,"
  },
  {
    "path": "src/invoke_training/_shared/flux/validation.py",
    "chars": 5099,
    "preview": "import logging\nimport os\n\nimport numpy as np\nimport torch\nimport torch.utils.data\nfrom accelerate import Accelerator\nfro"
  },
  {
    "path": "src/invoke_training/_shared/optimizer/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/_shared/optimizer/optimizer_utils.py",
    "chars": 1513,
    "preview": "import torch\nfrom prodigyopt import Prodigy\n\nfrom invoke_training.config.optimizer.optimizer_config import AdamOptimizer"
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/base_model_version.py",
    "chars": 2986,
    "preview": "from enum import Enum\n\nfrom transformers import PretrainedConfig\n\n\nclass BaseModelVersionEnum(Enum):\n    STABLE_DIFFUSIO"
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/checkpoint_utils.py",
    "chars": 2264,
    "preview": "from pathlib import Path\n\nimport torch\nfrom diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UN"
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/lora_checkpoint_utils.py",
    "chars": 7415,
    "preview": "import os\nfrom pathlib import Path\n\nimport peft\nimport torch\nfrom diffusers import UNet2DConditionModel\nfrom transformer"
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/min_snr_weighting.py",
    "chars": 1534,
    "preview": "import torch\nfrom diffusers import DDPMScheduler\n\n\ndef compute_snr(noise_scheduler: DDPMScheduler, timesteps: torch.Tens"
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/model_loading_utils.py",
    "chars": 8016,
    "preview": "import logging\nimport os\nimport typing\nfrom enum import Enum\n\nimport torch\nfrom diffusers import (\n    AutoencoderKL,\n  "
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/textual_inversion.py",
    "chars": 9269,
    "preview": "import logging\n\nimport torch\nfrom accelerate import Accelerator\nfrom transformers import CLIPTextModel, CLIPTokenizer, P"
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/tokenize_captions.py",
    "chars": 907,
    "preview": "import torch\nfrom transformers import CLIPTokenizer\n\nfrom invoke_training._shared.stable_diffusion.textual_inversion imp"
  },
  {
    "path": "src/invoke_training/_shared/stable_diffusion/validation.py",
    "chars": 10055,
    "preview": "import logging\nimport os\n\nimport numpy as np\nimport torch\nimport torch.utils.data\nfrom accelerate import Accelerator\nfro"
  },
  {
    "path": "src/invoke_training/_shared/tools/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/_shared/tools/generate_images.py",
    "chars": 4326,
    "preview": "import os\nfrom pathlib import Path\nfrom typing import Optional\n\nimport torch\nfrom tqdm import tqdm\n\nfrom invoke_training"
  },
  {
    "path": "src/invoke_training/_shared/utils/import_xformers.py",
    "chars": 291,
    "preview": "def import_xformers():\n    try:\n        import xformers  # noqa: F401\n    except ImportError:\n        raise ImportError("
  },
  {
    "path": "src/invoke_training/_shared/utils/jsonl.py",
    "chars": 525,
    "preview": "import json\nfrom pathlib import Path\nfrom typing import Any\n\n\ndef load_jsonl(jsonl_path: Path | str) -> list[Any]:\n    \""
  },
  {
    "path": "src/invoke_training/config/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/config/base_pipeline_config.py",
    "chars": 2184,
    "preview": "import typing\nfrom typing import Optional\n\nfrom invoke_training.config.config_base_model import ConfigBaseModel\n\n\nclass "
  },
  {
    "path": "src/invoke_training/config/config_base_model.py",
    "chars": 249,
    "preview": "from pydantic import BaseModel, ConfigDict\n\n\nclass ConfigBaseModel(BaseModel):\n    \"\"\"Base model for all invoke training"
  },
  {
    "path": "src/invoke_training/config/data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/config/data/data_loader_config.py",
    "chars": 7564,
    "preview": "from typing import Literal, Optional\n\nfrom invoke_training.config.config_base_model import ConfigBaseModel\nfrom invoke_t"
  },
  {
    "path": "src/invoke_training/config/data/dataset_config.py",
    "chars": 2917,
    "preview": "from typing import Annotated, Literal, Optional, Union\n\nfrom pydantic import Field\n\nfrom invoke_training.config.config_b"
  },
  {
    "path": "src/invoke_training/config/optimizer/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/config/optimizer/optimizer_config.py",
    "chars": 1500,
    "preview": "import typing\n\nfrom invoke_training.config.config_base_model import ConfigBaseModel\n\n\nclass AdamOptimizerConfig(ConfigBa"
  },
  {
    "path": "src/invoke_training/config/pipeline_config.py",
    "chars": 1198,
    "preview": "from typing import Annotated, Union\n\nfrom pydantic import Field\n\nfrom invoke_training.pipelines._experimental.sd_dpo_lor"
  },
  {
    "path": "src/invoke_training/model_merge/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/model_merge/extract_lora.py",
    "chars": 3642,
    "preview": "import torch\nimport tqdm\nfrom peft.peft_model import PeftModel\n\n# All original base model weights in a PeftModel have th"
  },
  {
    "path": "src/invoke_training/model_merge/merge_models.py",
    "chars": 3578,
    "preview": "from typing import Literal\n\nimport torch\nimport tqdm\n\nfrom invoke_training.model_merge.utils.normalize_weights import no"
  },
  {
    "path": "src/invoke_training/model_merge/merge_tasks_to_base.py",
    "chars": 2929,
    "preview": "from typing import Literal\n\nimport torch\nimport tqdm\nfrom peft.utils.merge_utils import dare_linear, dare_ties, ties\n\n\n@"
  },
  {
    "path": "src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py",
    "chars": 13509,
    "preview": "# This script is based on\n# https://raw.githubusercontent.com/kohya-ss/sd-scripts/bfb352bc433326a77aca3124248331eb60c49e"
  },
  {
    "path": "src/invoke_training/model_merge/scripts/merge_lora_into_model.py",
    "chars": 7109,
    "preview": "import argparse  # noqa: I001\nimport logging\nfrom pathlib import Path\n\nimport torch\nfrom diffusers import StableDiffusio"
  },
  {
    "path": "src/invoke_training/model_merge/scripts/merge_models.py",
    "chars": 5432,
    "preview": "import argparse\nimport logging\nfrom dataclasses import dataclass\nfrom pathlib import Path\n\nimport torch\nfrom diffusers i"
  },
  {
    "path": "src/invoke_training/model_merge/scripts/merge_task_models_to_base_model.py",
    "chars": 6717,
    "preview": "import argparse\nimport logging\nfrom pathlib import Path\n\nimport torch\nfrom diffusers import StableDiffusionPipeline, Sta"
  },
  {
    "path": "src/invoke_training/model_merge/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/model_merge/utils/normalize_weights.py",
    "chars": 135,
    "preview": "def normalize_weights(weights: list[float]) -> list[float]:\n    total = sum(weights)\n    return [weight / total for weig"
  },
  {
    "path": "src/invoke_training/model_merge/utils/parse_model_arg.py",
    "chars": 378,
    "preview": "def parse_model_arg(model: str, delimiter: str = \"::\") -> tuple[str, str | None]:\n    \"\"\"Parse a model argument into a m"
  },
  {
    "path": "src/invoke_training/pipelines/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/pipelines/_experimental/sd_dpo_lora/config.py",
    "chars": 11403,
    "preview": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field, model_validator\n\nfrom invoke_training.config.b"
  },
  {
    "path": "src/invoke_training/pipelines/_experimental/sd_dpo_lora/train.py",
    "chars": 27220,
    "preview": "import copy\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\nfrom pathlib i"
  },
  {
    "path": "src/invoke_training/pipelines/callbacks.py",
    "chars": 3289,
    "preview": "from abc import ABC\nfrom enum import Enum\n\n\nclass ModelType(Enum):\n    # At first glance, it feels like these model type"
  },
  {
    "path": "src/invoke_training/pipelines/flux/lora/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/pipelines/flux/lora/config.py",
    "chars": 10288,
    "preview": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field\n\nfrom invoke_training._shared.flux.lora_checkpo"
  },
  {
    "path": "src/invoke_training/pipelines/flux/lora/train.py",
    "chars": 30814,
    "preview": "import itertools\nimport json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\nfrom pathlib import Path\nf"
  },
  {
    "path": "src/invoke_training/pipelines/invoke_train.py",
    "chars": 2288,
    "preview": "import os\n\nfrom invoke_training.config.pipeline_config import PipelineConfig\nfrom invoke_training.pipelines._experimenta"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion/lora/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion/lora/config.py",
    "chars": 10750,
    "preview": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field, model_validator\n\nfrom invoke_training._shared."
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion/lora/train.py",
    "chars": 29952,
    "preview": "import itertools\nimport json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\nfrom pathlib import Path\nf"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion/textual_inversion/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion/textual_inversion/config.py",
    "chars": 9523,
    "preview": "from typing import Literal\n\nfrom pydantic import model_validator\n\nfrom invoke_training.config.base_pipeline_config impor"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion/textual_inversion/train.py",
    "chars": 19996,
    "preview": "import json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\n\nimport torch\nfrom accelerate import Accele"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/finetune/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/finetune/config.py",
    "chars": 8491,
    "preview": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field, model_validator\n\nfrom invoke_training.config.b"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/finetune/train.py",
    "chars": 20254,
    "preview": "import itertools\nimport json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\nfrom typing import Literal"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/lora/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/lora/config.py",
    "chars": 11150,
    "preview": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field, model_validator\n\nfrom invoke_training._shared."
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/lora/train.py",
    "chars": 34092,
    "preview": "import itertools\nimport json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\nfrom pathlib import Path\nf"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/config.py",
    "chars": 11458,
    "preview": "from typing import Literal\n\nfrom pydantic import model_validator\n\nfrom invoke_training.config.base_pipeline_config impor"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/train.py",
    "chars": 25664,
    "preview": "import itertools\nimport json\nimport logging\nimport math\nimport os\nimport time\nfrom pathlib import Path\nfrom typing impor"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/config.py",
    "chars": 9751,
    "preview": "from typing import Literal\n\nfrom pydantic import model_validator\n\nfrom invoke_training.config.base_pipeline_config impor"
  },
  {
    "path": "src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/train.py",
    "chars": 22955,
    "preview": "import json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\n\nimport torch\nimport torch.utils.data\nfrom "
  },
  {
    "path": "src/invoke_training/sample_configs/_experimental/sd_dpo_lora_pickapic_1x24gb.yaml",
    "chars": 1141,
    "preview": "# Training mode: Direct Preference Optimization LoRA Training\n# Dataset: A small subset of the pickapic_v2 dataset.\n# Ba"
  },
  {
    "path": "src/invoke_training/sample_configs/_experimental/sd_dpo_lora_refinement_pokemon_1x24gb.yaml",
    "chars": 1014,
    "preview": "# Training mode: Direct Preference Optimization LoRA Training\n# Base model:    SD 1.5\n# GPU:           1 x 24GB\n\ntype: S"
  },
  {
    "path": "src/invoke_training/sample_configs/flux_lora_1x40gb.yaml",
    "chars": 1339,
    "preview": "# Training mode: LoRA\n# Base model:    Flux.1-dev\n# Dataset:       Bruce the Gnome\n# GPU:           1 x 40GB\n\ntype: FLUX"
  },
  {
    "path": "src/invoke_training/sample_configs/sd_lora_baroque_1x8gb.yaml",
    "chars": 1567,
    "preview": "# Training mode: Finetuning with LoRA\n# Base model:    SD 1.5\n# Dataset:       https://huggingface.co/datasets/InvokeAI/"
  },
  {
    "path": "src/invoke_training/sample_configs/sd_textual_inversion_gnome_1x8gb.yaml",
    "chars": 1185,
    "preview": "# Training mode: Textual Inversion\n# Base model:    SD v1\n# GPU:           1 x 24GB\n\ntype: SD_TEXTUAL_INVERSION\nseed: 1\n"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_finetune_baroque_1x24gb.yaml",
    "chars": 1736,
    "preview": "# Training mode: Full Finetuning\n# Base model:    SDXL\n# Dataset:       https://huggingface.co/datasets/InvokeAI/nga-bar"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml",
    "chars": 1424,
    "preview": "# Training mode: Full finetune\n# Base model:    SDXL\n# Dataset:       Robocats\n# GPU:           1 x 24GB\n\ntype: SDXL_FIN"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_lora_and_ti_gnome_1x24gb.yaml",
    "chars": 1181,
    "preview": "# Training mode: Finetuning with LoRA and Textual Inversion\n# Base model:    SDXL 1.0\n# GPU:           1 x 24GB\n\ntype: S"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_lora_baroque_1x24gb.yaml",
    "chars": 1578,
    "preview": "# Training mode: Finetuning with LoRA\n# Base model:    SDXL 1.0\n# Dataset:       https://huggingface.co/datasets/InvokeA"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_lora_baroque_1x8gb.yaml",
    "chars": 1890,
    "preview": "# Training mode: Finetuning with LoRA\n# Base model:    SDXL 1.0\n# Dataset:       https://huggingface.co/datasets/InvokeA"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml",
    "chars": 1107,
    "preview": "# Training mode: LoRA with masks\n# Base model:    SDXL 1.0\n# Dataset:       Bruce the Gnome\n# GPU:           1 x 24GB\n\nt"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml",
    "chars": 1122,
    "preview": "# Training mode: Textual Inversion\n# Base model:    SDXL\n# GPU:           1 x 24GB\n\ntype: SDXL_TEXTUAL_INVERSION\nseed: 1"
  },
  {
    "path": "src/invoke_training/sample_configs/sdxl_textual_inversion_masks_gnome_1x24gb.yaml",
    "chars": 1173,
    "preview": "# Training mode: Textual Inversion with Masks\n# Base model:    SDXL\n# GPU:           1 x 24GB\n\ntype: SDXL_TEXTUAL_INVERS"
  },
  {
    "path": "src/invoke_training/scripts/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py",
    "chars": 4023,
    "preview": "import argparse\nimport json\nfrom pathlib import Path\n\nimport torch\nimport torch.utils.data\nfrom PIL import Image\nfrom tq"
  },
  {
    "path": "src/invoke_training/scripts/_experimental/masks/clipseg.py",
    "chars": 2228,
    "preview": "import torch\nfrom PIL import Image\nfrom transformers import AutoProcessor, CLIPSegForImageSegmentation, CLIPSegProcessor"
  },
  {
    "path": "src/invoke_training/scripts/_experimental/masks/generate_masks.py",
    "chars": 3022,
    "preview": "import argparse\nfrom pathlib import Path\n\nimport torch\nimport torch.utils.data\nfrom tqdm import tqdm\n\nfrom invoke_traini"
  },
  {
    "path": "src/invoke_training/scripts/_experimental/masks/generate_masks_for_jsonl_dataset.py",
    "chars": 5195,
    "preview": "import argparse\nfrom pathlib import Path\n\nimport torch\nimport torch.utils.data\nfrom tqdm import tqdm\n\nfrom invoke_traini"
  },
  {
    "path": "src/invoke_training/scripts/_experimental/rank_images.py",
    "chars": 3899,
    "preview": "import argparse\nimport os\nimport time\nfrom pathlib import Path\nfrom typing import Literal\n\nimport gradio as gr\nimport ya"
  },
  {
    "path": "src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py",
    "chars": 1475,
    "preview": "import argparse\nfrom pathlib import Path\n\nimport torch\n\nfrom invoke_training._shared.stable_diffusion.lora_checkpoint_ut"
  },
  {
    "path": "src/invoke_training/scripts/invoke_generate_images.py",
    "chars": 4689,
    "preview": "import argparse\nfrom pathlib import Path\n\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import Pipel"
  },
  {
    "path": "src/invoke_training/scripts/invoke_train.py",
    "chars": 846,
    "preview": "import argparse\nfrom pathlib import Path\n\nimport yaml\nfrom pydantic import TypeAdapter\n\nfrom invoke_training.config.pipe"
  },
  {
    "path": "src/invoke_training/scripts/invoke_train_ui.py",
    "chars": 567,
    "preview": "import argparse\n\nimport uvicorn\n\nfrom invoke_training.ui.app import build_app\n\n\ndef main():\n    parser = argparse.Argume"
  },
  {
    "path": "src/invoke_training/scripts/invoke_visualize_data_loading.py",
    "chars": 4488,
    "preview": "import argparse\nimport os\nimport time\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport yaml\nfrom PIL imp"
  },
  {
    "path": "src/invoke_training/scripts/utils/image_dir_dataset.py",
    "chars": 1756,
    "preview": "import os\nimport typing\n\nimport torch\nfrom PIL import Image\n\n\nclass ImageDirDataset(torch.utils.data.Dataset):\n    \"\"\"A "
  },
  {
    "path": "src/invoke_training/ui/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/ui/app.py",
    "chars": 875,
    "preview": "from pathlib import Path\n\nimport gradio as gr\nfrom fastapi import FastAPI\nfrom fastapi.responses import FileResponse\nfro"
  },
  {
    "path": "src/invoke_training/ui/config_groups/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "src/invoke_training/ui/config_groups/aspect_ratio_bucket_config_group.py",
    "chars": 2725,
    "preview": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.data.data_loader_config import AspectRatioBucke"
  },
  {
    "path": "src/invoke_training/ui/config_groups/base_pipeline_config_group.py",
    "chars": 6243,
    "preview": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.base_pipeline_config import BasePipelineConfig\n"
  },
  {
    "path": "src/invoke_training/ui/config_groups/dataset_config_group.py",
    "chars": 12854,
    "preview": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.data.dataset_config import (\n    HFHubImageCapt"
  },
  {
    "path": "src/invoke_training/ui/config_groups/flux_lora_config_group.py",
    "chars": 20273,
    "preview": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.flux.lora.config import FluxLoraConfig\nfrom invoke_tr"
  },
  {
    "path": "src/invoke_training/ui/config_groups/image_caption_sd_data_loader_config_group.py",
    "chars": 5961,
    "preview": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.data.data_loader_config import ImageCaptionSDDa"
  },
  {
    "path": "src/invoke_training/ui/config_groups/optimizer_config_group.py",
    "chars": 6851,
    "preview": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.optimizer.optimizer_config import AdamOptimizer"
  },
  {
    "path": "src/invoke_training/ui/config_groups/sd_lora_config_group.py",
    "chars": 14076,
    "preview": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig\nfrom"
  },
  {
    "path": "src/invoke_training/ui/config_groups/sd_textual_inversion_config_group.py",
    "chars": 13704,
    "preview": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTe"
  },
  {
    "path": "src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py",
    "chars": 13577,
    "preview": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetu"
  },
  {
    "path": "src/invoke_training/ui/config_groups/sdxl_lora_and_textual_inversion_config_group.py",
    "chars": 17946,
    "preview": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config"
  },
  {
    "path": "src/invoke_training/ui/config_groups/sdxl_lora_config_group.py",
    "chars": 14513,
    "preview": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig"
  },
  {
    "path": "src/invoke_training/ui/config_groups/sdxl_textual_inversion_config_group.py",
    "chars": 14131,
    "preview": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import S"
  },
  {
    "path": "src/invoke_training/ui/config_groups/textual_inversion_sd_data_loader_config_group.py",
    "chars": 7577,
    "preview": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.data.data_loader_config import (\n    TextualInv"
  },
  {
    "path": "src/invoke_training/ui/config_groups/ui_config_element.py",
    "chars": 1662,
    "preview": "from typing import Any\n\nimport gradio as gr\n\n\nclass UIConfigElement:\n    \"\"\"A base class for UI blocks that represent a "
  },
  {
    "path": "src/invoke_training/ui/gradio_blocks/header.py",
    "chars": 598,
    "preview": "import gradio as gr\n\nfrom invoke_training.ui.utils.utils import get_assets_dir_path\n\n\nclass Header:\n    def __init__(sel"
  },
  {
    "path": "src/invoke_training/ui/gradio_blocks/pipeline_tab.py",
    "chars": 6915,
    "preview": "import typing\n\nimport gradio as gr\nimport yaml\n\nfrom invoke_training.config.pipeline_config import PipelineConfig\nfrom i"
  },
  {
    "path": "src/invoke_training/ui/index.html",
    "chars": 2058,
    "preview": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n    <meta charset=\"UTF-8\">\n    <meta name=\"viewport\" content=\"width=device-width"
  },
  {
    "path": "src/invoke_training/ui/pages/data_page.py",
    "chars": 14842,
    "preview": "from pathlib import Path\n\nimport gradio as gr\nfrom PIL import Image\n\nfrom invoke_training._shared.data.datasets.image_ca"
  },
  {
    "path": "src/invoke_training/ui/pages/training_page.py",
    "chars": 8100,
    "preview": "import os\nimport subprocess\nimport tempfile\nimport time\n\nimport gradio as gr\nimport yaml\n\nfrom invoke_training.config.pi"
  },
  {
    "path": "src/invoke_training/ui/utils/prompts.py",
    "chars": 3059,
    "preview": "NEGATIVE_PROMPT_DELIMITER = \"[NEG]\"\n\n\ndef split_pos_neg_prompts(prompt: str) -> tuple[str, str]:\n    \"\"\"Split a prompt c"
  },
  {
    "path": "src/invoke_training/ui/utils/utils.py",
    "chars": 1016,
    "preview": "import typing\nfrom pathlib import Path\n\nimport yaml\nfrom pydantic import TypeAdapter\n\nfrom invoke_training.config.pipeli"
  },
  {
    "path": "tests/invoke_training/_shared/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "tests/invoke_training/_shared/checkpoints/test_checkpoint_tracker.py",
    "chars": 4020,
    "preview": "import os\nimport tempfile\nfrom pathlib import Path\n\nimport pytest\n\nfrom invoke_training._shared.checkpoints.checkpoint_t"
  },
  {
    "path": "tests/invoke_training/_shared/checkpoints/test_serialization.py",
    "chars": 1217,
    "preview": "import os\nimport tempfile\n\nimport pytest\nimport torch\n\nfrom invoke_training._shared.checkpoints.serialization import (\n "
  },
  {
    "path": "tests/invoke_training/_shared/data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "tests/invoke_training/_shared/data/data_loaders/__init__.py",
    "chars": 0,
    "preview": ""
  }
]

// ... and 46 more files (download for full content)

About this extraction

This page contains the full source code of the invoke-ai/invoke-training GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 246 files (930.6 KB), approximately 219.8k tokens, and a symbol index with 581 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!