Showing preview only (1,010K chars total). Download the full file or copy to clipboard to get everything.
Repository: invoke-ai/invoke-training
Branch: main
Commit: 363f83cdb5e6
Files: 246
Total size: 930.6 KB
Directory structure:
gitextract_5kosyax7/
├── .github/
│ └── workflows/
│ ├── deploy.yaml
│ └── test.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── docs/
│ ├── contributing/
│ │ ├── development_environment.md
│ │ ├── directory_structure.md
│ │ ├── documentation.md
│ │ └── tests.md
│ ├── get-started/
│ │ ├── installation.md
│ │ └── quick-start.md
│ ├── guides/
│ │ ├── dataset_formats.md
│ │ ├── model_merge.md
│ │ └── stable_diffusion/
│ │ ├── dpo_lora_sd.md
│ │ ├── gnome_lora_masks_sdxl.md
│ │ ├── robocats_finetune_sdxl.md
│ │ └── textual_inversion_sdxl.md
│ ├── index.md
│ ├── reference/
│ │ └── config/
│ │ ├── index.md
│ │ ├── pipelines/
│ │ │ ├── sd_lora.md
│ │ │ ├── sd_textual_inversion.md
│ │ │ ├── sdxl_finetune.md
│ │ │ ├── sdxl_lora.md
│ │ │ ├── sdxl_lora_and_textual_inversion.md
│ │ │ └── sdxl_textual_inversion.md
│ │ └── shared/
│ │ ├── data/
│ │ │ ├── data_loader_config.md
│ │ │ └── dataset_config.md
│ │ └── optimizer_config.md
│ └── templates/
│ └── python/
│ └── material/
│ └── labels.html
├── mkdocs.yml
├── pyproject.toml
├── sample_data/
│ └── bruce_the_gnome/
│ └── data.jsonl
├── src/
│ └── invoke_training/
│ ├── __init__.py
│ ├── _shared/
│ │ ├── __init__.py
│ │ ├── accelerator/
│ │ │ ├── __init__.py
│ │ │ └── accelerator_utils.py
│ │ ├── checkpoints/
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint_tracker.py
│ │ │ ├── lora_checkpoint_utils.py
│ │ │ └── serialization.py
│ │ ├── data/
│ │ │ ├── ARCHITECTURE.md
│ │ │ ├── __init__.py
│ │ │ ├── data_loaders/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── dreambooth_sd_dataloader.py
│ │ │ │ ├── image_caption_flux_dataloader.py
│ │ │ │ ├── image_caption_sd_dataloader.py
│ │ │ │ ├── image_pair_preference_sd_dataloader.py
│ │ │ │ └── textual_inversion_sd_dataloader.py
│ │ │ ├── datasets/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── build_dataset.py
│ │ │ │ ├── hf_image_caption_dataset.py
│ │ │ │ ├── hf_image_pair_preference_dataset.py
│ │ │ │ ├── image_caption_dir_dataset.py
│ │ │ │ ├── image_caption_jsonl_dataset.py
│ │ │ │ ├── image_dir_dataset.py
│ │ │ │ ├── image_pair_preference_dataset.py
│ │ │ │ └── transform_dataset.py
│ │ │ ├── samplers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── aspect_ratio_bucket_batch_sampler.py
│ │ │ │ ├── batch_offset_sampler.py
│ │ │ │ ├── concat_sampler.py
│ │ │ │ ├── interleaved_sampler.py
│ │ │ │ └── offset_sampler.py
│ │ │ ├── transforms/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── caption_prefix_transform.py
│ │ │ │ ├── concat_fields_transform.py
│ │ │ │ ├── constant_field_transform.py
│ │ │ │ ├── drop_field_transform.py
│ │ │ │ ├── flux_image_transform.py
│ │ │ │ ├── load_cache_transform.py
│ │ │ │ ├── sd_image_transform.py
│ │ │ │ ├── shuffle_caption_transform.py
│ │ │ │ ├── template_caption_transform.py
│ │ │ │ └── tensor_disk_cache.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── aspect_ratio_bucket_manager.py
│ │ │ ├── resize.py
│ │ │ └── resolution.py
│ │ ├── flux/
│ │ │ ├── encoding_utils.py
│ │ │ ├── lora_checkpoint_utils.py
│ │ │ ├── model_loading_utils.py
│ │ │ └── validation.py
│ │ ├── optimizer/
│ │ │ ├── __init__.py
│ │ │ └── optimizer_utils.py
│ │ ├── stable_diffusion/
│ │ │ ├── __init__.py
│ │ │ ├── base_model_version.py
│ │ │ ├── checkpoint_utils.py
│ │ │ ├── lora_checkpoint_utils.py
│ │ │ ├── min_snr_weighting.py
│ │ │ ├── model_loading_utils.py
│ │ │ ├── textual_inversion.py
│ │ │ ├── tokenize_captions.py
│ │ │ └── validation.py
│ │ ├── tools/
│ │ │ ├── __init__.py
│ │ │ └── generate_images.py
│ │ └── utils/
│ │ ├── import_xformers.py
│ │ └── jsonl.py
│ ├── config/
│ │ ├── __init__.py
│ │ ├── base_pipeline_config.py
│ │ ├── config_base_model.py
│ │ ├── data/
│ │ │ ├── __init__.py
│ │ │ ├── data_loader_config.py
│ │ │ └── dataset_config.py
│ │ ├── optimizer/
│ │ │ ├── __init__.py
│ │ │ └── optimizer_config.py
│ │ └── pipeline_config.py
│ ├── model_merge/
│ │ ├── __init__.py
│ │ ├── extract_lora.py
│ │ ├── merge_models.py
│ │ ├── merge_tasks_to_base.py
│ │ ├── scripts/
│ │ │ ├── extract_lora_from_model_diff.py
│ │ │ ├── merge_lora_into_model.py
│ │ │ ├── merge_models.py
│ │ │ └── merge_task_models_to_base_model.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── normalize_weights.py
│ │ └── parse_model_arg.py
│ ├── pipelines/
│ │ ├── __init__.py
│ │ ├── _experimental/
│ │ │ └── sd_dpo_lora/
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ ├── callbacks.py
│ │ ├── flux/
│ │ │ └── lora/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ ├── invoke_train.py
│ │ ├── stable_diffusion/
│ │ │ ├── __init__.py
│ │ │ ├── lora/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── config.py
│ │ │ │ └── train.py
│ │ │ └── textual_inversion/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ └── stable_diffusion_xl/
│ │ ├── __init__.py
│ │ ├── finetune/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ ├── lora/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ ├── lora_and_textual_inversion/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ └── textual_inversion/
│ │ ├── __init__.py
│ │ ├── config.py
│ │ └── train.py
│ ├── sample_configs/
│ │ ├── _experimental/
│ │ │ ├── sd_dpo_lora_pickapic_1x24gb.yaml
│ │ │ └── sd_dpo_lora_refinement_pokemon_1x24gb.yaml
│ │ ├── flux_lora_1x40gb.yaml
│ │ ├── sd_lora_baroque_1x8gb.yaml
│ │ ├── sd_textual_inversion_gnome_1x8gb.yaml
│ │ ├── sdxl_finetune_baroque_1x24gb.yaml
│ │ ├── sdxl_finetune_robocats_1x24gb.yaml
│ │ ├── sdxl_lora_and_ti_gnome_1x24gb.yaml
│ │ ├── sdxl_lora_baroque_1x24gb.yaml
│ │ ├── sdxl_lora_baroque_1x8gb.yaml
│ │ ├── sdxl_lora_masks_gnome_1x24gb.yaml
│ │ ├── sdxl_textual_inversion_gnome_1x24gb.yaml
│ │ └── sdxl_textual_inversion_masks_gnome_1x24gb.yaml
│ ├── scripts/
│ │ ├── __init__.py
│ │ ├── _experimental/
│ │ │ ├── auto_caption/
│ │ │ │ └── auto_caption_images.py
│ │ │ ├── masks/
│ │ │ │ ├── clipseg.py
│ │ │ │ ├── generate_masks.py
│ │ │ │ └── generate_masks_for_jsonl_dataset.py
│ │ │ └── rank_images.py
│ │ ├── convert_sd_lora_to_kohya_format.py
│ │ ├── invoke_generate_images.py
│ │ ├── invoke_train.py
│ │ ├── invoke_train_ui.py
│ │ ├── invoke_visualize_data_loading.py
│ │ └── utils/
│ │ └── image_dir_dataset.py
│ └── ui/
│ ├── __init__.py
│ ├── app.py
│ ├── config_groups/
│ │ ├── __init__.py
│ │ ├── aspect_ratio_bucket_config_group.py
│ │ ├── base_pipeline_config_group.py
│ │ ├── dataset_config_group.py
│ │ ├── flux_lora_config_group.py
│ │ ├── image_caption_sd_data_loader_config_group.py
│ │ ├── optimizer_config_group.py
│ │ ├── sd_lora_config_group.py
│ │ ├── sd_textual_inversion_config_group.py
│ │ ├── sdxl_finetune_config_group.py
│ │ ├── sdxl_lora_and_textual_inversion_config_group.py
│ │ ├── sdxl_lora_config_group.py
│ │ ├── sdxl_textual_inversion_config_group.py
│ │ ├── textual_inversion_sd_data_loader_config_group.py
│ │ └── ui_config_element.py
│ ├── gradio_blocks/
│ │ ├── header.py
│ │ └── pipeline_tab.py
│ ├── index.html
│ ├── pages/
│ │ ├── data_page.py
│ │ └── training_page.py
│ └── utils/
│ ├── prompts.py
│ └── utils.py
└── tests/
└── invoke_training/
├── _shared/
│ ├── __init__.py
│ ├── checkpoints/
│ │ ├── test_checkpoint_tracker.py
│ │ └── test_serialization.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── data_loaders/
│ │ │ ├── __init__.py
│ │ │ ├── test_dreambooth_sd_dataloader.py
│ │ │ ├── test_image_caption_sd_dataloader.py
│ │ │ ├── test_image_pair_preference_sd_dataloader.py
│ │ │ └── test_textual_inversion_sd_dataloader.py
│ │ ├── dataset_fixtures.py
│ │ ├── datasets/
│ │ │ ├── __init__.py
│ │ │ ├── test_hf_image_caption_dataset.py
│ │ │ ├── test_hf_image_pair_preference_dataset.py
│ │ │ ├── test_image_caption_dir_dataset.py
│ │ │ ├── test_image_caption_jsonl_dataset.py
│ │ │ ├── test_image_dir_dataset.py
│ │ │ ├── test_image_pair_preference_dataset.py
│ │ │ └── test_transform_dataset.py
│ │ ├── samplers/
│ │ │ ├── __init__.py
│ │ │ ├── test_aspect_ratio_bucket_batch_sampler.py
│ │ │ ├── test_batch_offset_sampler.py
│ │ │ ├── test_concat_sampler.py
│ │ │ ├── test_interleaved_sampler.py
│ │ │ └── test_offset_sampler.py
│ │ ├── transforms/
│ │ │ ├── __init__.py
│ │ │ ├── test_caption_prefix_transform.py
│ │ │ ├── test_concat_fields_transform.py
│ │ │ ├── test_constant_field_transform.py
│ │ │ ├── test_drop_field_transform.py
│ │ │ ├── test_load_cache_transform.py
│ │ │ ├── test_sd_image_transform.py
│ │ │ ├── test_shuffle_caption_transform.py
│ │ │ ├── test_template_caption_transform.py
│ │ │ └── test_tensor_disk_cache.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── test_aspect_ratio_bucket_manager.py
│ │ ├── test_resize.py
│ │ └── test_resolution.py
│ ├── stable_diffusion/
│ │ ├── __init__.py
│ │ ├── test_base_model_version.py
│ │ ├── test_lora_checkpoint_utils.py
│ │ ├── test_model_loading_utils.py
│ │ ├── test_textual_inversion.py
│ │ └── ti_embedding_checkpoint_fixture.py
│ └── utils/
│ └── test_jsonl.py
├── config/
│ └── pipelines/
│ └── test_pipeline_config.py
├── model_merge/
│ ├── __init__.py
│ ├── test_merge_models.py
│ ├── test_merge_tasks_to_base.py
│ └── utils.py
└── ui/
└── utils/
└── test_prompts.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/deploy.yaml
================================================
name: Deploy invoke-training docs
on:
push:
branches:
- main
permissions:
contents: write
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Configure Git Credentials
run: |
git config user.name github-actions[bot]
git config user.email 41898282+github-actions[bot]@users.noreply.github.com
- uses: actions/setup-python@v4
with:
python-version: "3.10"
cache: pip
cache-dependency-path: pyproject.toml
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install .[test]
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- uses: actions/cache@v3
with:
key: mkdocs-material-${{ env.cache_id }}
path: .cache
restore-keys: |
mkdocs-material-
- run: mkdocs gh-deploy --force
================================================
FILE: .github/workflows/test.yaml
================================================
name: Test invoke-training
on:
push:
branches:
- main
pull_request:
workflow_dispatch:
jobs:
build:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install .[test]
- name: Ruff lint
run: |
ruff check --output-format=github .
- name: Ruff format
run: |
ruff format --check .
- name: Test with pytest
run: |
pytest tests --junitxml=junit/test-results-${{ matrix.python-version }}.xml -m "not cuda and not loads_model"
- name: Upload pytest test results
uses: actions/upload-artifact@v4
with:
name: pytest-results-${{ matrix.python-version }}
path: junit/test-results-${{ matrix.python-version }}.xml
# Use always() to always run this step to publish test results when there are test failures.
if: ${{ always() }}
================================================
FILE: .gitignore
================================================
/output/
/test_configs/
/data/
# pyenv
.python-version
# VSCode
.vscode/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
junit/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.aider*
================================================
FILE: .pre-commit-config.yaml
================================================
# See https://pre-commit.com/ for usage and config.
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.7
hooks:
# Run the linter.
- id: ruff
# Run the formatter.
- id: ruff-format
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# invoke-training
A library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion, etc.) that can be used in [InvokeAI](https://github.com/invoke-ai/InvokeAI).
> [!WARNING]
> `invoke-training` is still under active development, and breaking changes are likely. Full backwards compatibility will not be guaranteed until v1.0.0.
> In the meantime, I recommend pinning to a specific commit hash.
## Documentation
<https://invoke-ai.github.io/invoke-training/>
## Training Modes
- Stable Diffusion
- LoRA
- DreamBooth LoRA
- Textual Inversion
- Stable Diffusion XL
- Full finetuning
- LoRA
- DreamBooth LoRA
- Textual Inversion
- LoRA and Textual Inversion
More training modes coming soon!
## Installation
See the [Installation](https://invoke-ai.github.io/invoke-training/get-started/installation/) section of the documentation.
## Quick Start
`invoke-training` pipelines can be configured and launched from either the CLI or the GUI.
### CLI
Run training via the CLI with type-checked YAML configuration files for maximum control:
```bash
invoke-train --cfg-file src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml
```
### GUI
Run training via the GUI for a simpler starting point.
```bash
invoke-train-ui
# Or, you can optionally override the default host and port:
invoke-train-ui --host 0.0.0.0 --port 1234
```
## Features
Training progress can be monitored with [Tensorboard](https://www.tensorflow.org/tensorboard):

_Validation images in the Tensorboard UI._
All trained models are compatible with InvokeAI:

_Example image generated with the prompt "A cute yoda pokemon creature." and a trained Pokemon LoRA._
## Contributing
Contributors are welcome. For developer guidance, see the [Contributing](https://invoke-ai.github.io/invoke-training/contributing/development_environment/) section of the documentation.
================================================
FILE: docs/contributing/development_environment.md
================================================
# Development Environment Setup
See the [developer installation instructions](../get-started/installation.md#developer-installation).
================================================
FILE: docs/contributing/directory_structure.md
================================================
# Directory Structure
```bash
invoke-training/
├── README.md
├── docs/
├── src/
│ └── invoke_training/
│ ├── _shared/ # Utilities shared across multiple pipelines. High unit test coverage.
│ ├── config/ # Config structures shared by multiple pipelines.
│ ├── pipelines/ # Each pipeline is isolated in its own directory with a train.py and config.py.
│ │ ├── stable_diffusion/
│ │ │ ├── lora/
│ │ │ │ ├── config.py
│ │ │ │ └── train.py
│ │ │ └── textual_inversion/
│ │ │ └── ...
│ │ ├── stable_diffusion_xl/
│ │ └── ...
│ └── scripts/ # Main entrypoints.
└── tests/ # Mirrors src/ directory.
```
================================================
FILE: docs/contributing/documentation.md
================================================
# Documentation
The documentation site is generated using [mkdocs](https://www.mkdocs.org/) and [mkdocstrings-python](https://mkdocstrings.github.io/python/).
To view your documentation changes locally, run `mkdocs serve`.
================================================
FILE: docs/contributing/tests.md
================================================
# Tests
Run all unit tests with:
```bash
pytest tests/
```
There are some test 'markers' defined in [pyproject.toml](https://github.com/invoke-ai/invoke-training/blob/main/pyproject.toml) that can be used to skip some tests. For example, the following command skips tests that require a GPU or require downloading model weights:
```bash
pytest tests/ -m "not cuda and not loads_model"
```
================================================
FILE: docs/get-started/installation.md
================================================
# Installation
## Requirements
1. Python 3.10, 3.11 and 3.12 are currently supported. Check your Python version by running `python -V`.
2. An NVIDIA GPU with >= 8 GB VRAM is recommended for model training.
## Basic Installation
0. Open your terminal and navigate to the directory where you want to clone the `invoke-training` repo.
1. Clone the repo:
```bash
git clone https://github.com/invoke-ai/invoke-training.git
```
2. Create and activate a python [virtual environment](https://docs.python.org/3/library/venv.html#creating-virtual-environments). This creates an isolated environment for `invoke-training` and its dependencies that won't interfere with other python environments on your system, including any installations of [InvokeAI](https://www.github.com/invoke-ai/invokeai).
```bash
# Navigate to the invoke-training directory.
cd invoke-training
# Create a new virtual environment named `invoketraining`.
python -m venv invoketraining
# Activate the new virtual environment.
# On Windows:
.\invoketraining\Scripts\activate
# On MacOS / Linux:
source invoketraining/bin/activate
```
3. Install `invoke-training` and its dependencies. Run the appropriate install command for your system.
```bash
# A recent version of pip is required, so first upgrade pip:
python -m pip install --upgrade pip
# Install - Windows or Linux with an NVIDIA GPU:
pip install ".[test]" --extra-index-url https://download.pytorch.org/whl/cu126
# Install - Linux with no GPU:
pip install ".[test]" --extra-index-url https://download.pytorch.org/whl/cpu
# Install - All other systems:
pip install ".[test]"
```
Each time you run `invoke-training` in the future, first activate the virtual environment you created during installation, using the same command you used during installation.
## Developer Installation
Consider forking the repo if you plan to contribute code changes.
Follow the above installation instructions, cloning your fork instead of this repo if you made a fork.
Next, we suggest setting up the repo's pre-commit hooks to automatically format and lint your contributions:
1. (_Optional_) Install the pre-commit hooks: `pre-commit install`. This will run static analysis tools (ruff) on `git commit`.
2. (_Optional_) Set up `ruff` in your IDE of choice.
================================================
FILE: docs/get-started/quick-start.md
================================================
# Quick Start
`invoke-training` has both a GUI and a CLI (for advanced users). The instructions for getting started with both options can be found on this page.
There is also a video introduction to `invoke-training`:
<iframe width="560" height="315" src="https://www.youtube.com/embed/OZIz2vvtlM4?si=iR73F0IhlsolyYAl" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" referrerpolicy="strict-origin-when-cross-origin" allowfullscreen></iframe>
## Quick Start - GUI
### 1. Installation
Follow the [`invoke-training` installation instructions](./installation.md).
### 2. Launch the GUI
Activate the virtual environment you created during installation, using the same command you used during installation.
You'll need to do this every time you run `invoke-training`.
```bash
# From the invoke-training directory:
invoke-train-ui
# Or, you can optionally override the default host and port:
invoke-train-ui --host 0.0.0.0 --port 1234
```
Access the GUI in your browser at the URL printed to the console.
### 3. Configure the training job
Select the desired training pipeline type in the top-level tab.
For this tutorial, we don't need to change any of the configuration values. The preset configuration should work well.
### 4. Generate the YAML configuration
Click on 'Generate Config' to generate a YAML configuration file. This YAML configuration file could be used to launch the training job from the CLI, if desired.
### 5. Start training
Click on 'Start Training' and check your terminal for progress logs.
### 6. Monitor training
Monitor the training process with Tensorboard by running `tensorboard --logdir output/` and visiting [localhost:6006](http://localhost:6006) in your browser. Here you can see generated validation images throughout the training process.

_Validation images in the Tensorboard UI._
### 7. InvokeAI
Select a checkpoint based on the quality of the generated images.
If you haven't already, set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.
Copy your selected LoRA checkpoint into your `${INVOKEAI_ROOT}/autoimport/lora` directory. For example:
```bash
# Note: You will have to replace the timestamp in the checkpoint path.
cp output/1691088769.5694647/checkpoint_epoch-00000002.safetensors ${INVOKEAI_ROOT}/autoimport/lora/pokemon_epoch-00000002.safetensors
```
You can now use your trained Pokemon LoRA in the InvokeAI UI! 🎉

_Example image generated with the prompt "A cute yoda pokemon creature." and Pokemon LoRA._
## Quick Start - CLI
### 1. Installation
Follow the [`invoke-training` installation instructions](./installation.md).
### 2. Training
Activate the virtual environment you created during installation, using the same command you used during installation.
You'll need to do this every time you run `invoke-training`.
See the [Textual Inversion - SDXL](../guides/stable_diffusion/textual_inversion_sdxl.md) tutorial for instructions on how to train a model via the CLI.
================================================
FILE: docs/guides/dataset_formats.md
================================================
# Dataset Formats
`invoke-training` supports the following dataset formats:
- `IMAGE_CAPTION_JSONL_DATASET`: A local image-caption dataset described by a single `.jsonl` file.
- `IMAGE_CAPTION_DIR_DATASET`: A local directory of images with associated `.txt` caption files.
- `IMAGE_DIR_DATASET`: A local directory of images (without captions).
- `HF_HUB_IMAGE_CAPTION_DATASET`: A Hugging Face Hub dataset containing images and captions.
See the documentation for a particular training pipeline to see which dataset formats it supports.
The following sections explain each of these formats in more detail.
## `IMAGE_CAPTION_JSONL_DATASET`
Config documentation: [ImageCaptionJsonlDatasetConfig][invoke_training.config.data.dataset_config.ImageCaptionJsonlDatasetConfig]
An `IMAGE_CAPTION_JSONL_DATASET` consists of a single `.jsonl` file containing image paths and associated captions.
Sample directory structure:
```bash
my_custom_dataset/
├── data.jsonl
└── train/
├── 0001.png
├── 0002.png
├── 0003.png
└── ...
```
The contents of `data.jsonl` would be:
```json
{"file_name": "train/0001.png", "text": "This is a caption describing image 0001."}
{"file_name": "train/0002.png", "text": "This is a caption describing image 0002."}
{"file_name": "train/0003.png", "text": "This is a caption describing image 0003."}
```
The image file paths can be either absolute paths, or relative to the `.jsonl` file.
Finally, this dataset can be used with the following pipeline dataset configuration:
```yaml
type: IMAGE_CAPTION_JSONL_DATASET
jsonl_path: /path/to/my_custom_dataset/data.jsonl
image_column: file_name
caption_column: text
```
A useful characteristic of this dataset format is that a `.jsonl` file can reference an image file anywhere on the local disk. It is common to maintain multiple `.jsonl` datasets that reference some of the same images without needing multiple copies of those images on disk.
## `IMAGE_CAPTION_DIR_DATASET`
Config documentation: [ImageCaptionDirDatasetConfig][invoke_training.config.data.dataset_config.ImageCaptionDirDatasetConfig]
A `IMAGE_CAPTION_DIR_DATASET` consists of a directory of image files and corresponding `.txt` caption files of the same name.
Sample directory structure:
```bash
my_custom_dataset/
├── 0001.png
├── 0001.txt
├── 0002.jpg
├── 0002.txt
├── 0003.png
├── 0003.txt
└── ...
```
Each `.txt` file should contain a caption on the first line of the file. Here are the sample contents of `0001.txt`:
```txt title="0001.txt"
this is a caption for example 0001
```
This dataset can be used with the following pipeline dataset configuration:
```yaml
type: IMAGE_CAPTION_DIR_DATASET
dataset_dir: /path/to/my_custom_dataset
```
## `IMAGE_DIR_DATASET`
Config documentation: [ImageDirDatasetConfig][invoke_training.config.data.dataset_config.ImageDirDatasetConfig]
A `IMAGE_DIR_DATASET` consists of a single directory of images (without captions).
Sample directory structure:
```bash
my_custom_dataset/
├── 0001.png
├── 0002.jpg
├── 0003.png
└── ...
```
This dataset can be used with the following pipeline dataset configuration:
```yaml
type: IMAGE_DIR_DATASET
dataset_dir: /path/to/my_custom_dataset
```
## `HF_HUB_IMAGE_CAPTION_DATASET`
Config documentation: [HFHubImageCaptionDatasetConfig][invoke_training.config.data.dataset_config.HFHubImageCaptionDatasetConfig]
The `HF_HUB_IMAGE_CAPTION_DATASET` dataset format can be used to access publicly available datasets on the [Hugging Face Hub](https://huggingface.co/datasets). You can filter for the `Text-to-Image` task to find relevant datasets that contain both an image column and a caption column. [lambdalabs/pokemon-blip-captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) is a popular choice if you're not sure where to start.
================================================
FILE: docs/guides/model_merge.md
================================================
# Model Merging
`invoke-training` provides utility scripts for several common model merging workflows. This page contains a summary of the available tools.
## `extract_lora_from_model_diff.py`
Extract a LoRA model that represents the difference between two base models.
Note that the extracted LoRA model is a lossy representation of the difference between the models, so some degradation in quality is expected.
For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py -h
```
## `merge_lora_into_model.py`
Merge a LoRA model into a base model to produce a new base model.
For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/merge_lora_into_model.py -h
```
## `merge_models.py`
Merge 2 or more base models to produce a single base model (using either LERP or SLERP). This is a simple merge strategy that merges all model weights in the same way.
For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/merge_models.py -h
```
## `merge_task_models_to_base_model.py`
Merge 1 or more task-specific base models into a single starting base model (using either [TIES](https://arxiv.org/abs/2306.01708) or [DARE](https://arxiv.org/abs/2311.03099)). This merge strategy aims to preserve the task-specific behaviors of the task models while making only small changes to the original base model. This approach enables multiple task models to be merged without excessive interference between them.
If you want to merge a task-specific LoRA into a base model using this strategy, first use `merge_lora_into_model.py` to produce a task-specific base model, then merge that new base model using this strategy.
For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/merge_task_models_to_base_model.py -h
```
================================================
FILE: docs/guides/stable_diffusion/dpo_lora_sd.md
================================================
# (Experimental) Diffusion DPO - SD
!!! tip "Experimental"
The Diffusion Direct Preference Optimization training pipeline is still experimental. Support may be dropped at any time.
This tutorial walks through some initial experiments around using Diffusion Direct Preference Optimization (DPO) ([paper](https://arxiv.org/abs/2311.12908)) to train Stable Diffusion LoRA models.
## Experiment 1: `pickapic_v2` LoRA Training
The Diffusion-DPO paper does full model fine-tuning on the [pickapic_v2](https://huggingface.co/datasets/yuvalkirstain/pickapic_v2) dataset, which consists of roughly 1M AI-generated image pairs with preference annotations. In this experiment, we attempt to fine-tune a Stable Diffusion LoRA model using a small subset of the pickapic_v2 dataset.
Run this experiment with the following command:
```bash
invoke-train -c src/invoke_training/sample_configs/_experimental/sd_dpo_lora_pickapic_1x24gb.yaml
```
Here is a cherry-picked example of a prompt for which this training process was clearly beneficial.
Prompt: "*A galaxy-colored figurine is floating over the sea at sunset, photorealistic*"
| Before DPO Training | After DPO Training (same seed)|
| - | - |
|  |  |
## Experiment 2: LoRA Model Refinement
As a second experiment, we attempt the following workflow:
1. Train a Stable Diffusion LoRA model on a particular style.
2. Generate pairs of images with the trained LoRA model.
3. Annotate the preferred image from each pair.
4. Apply Diffusion-DPO to the preference-annotated pairs to further fine-tune the LoRA model.
Note: The steps listed below are pretty rough. They are included primarily for reference for someone looking to resume this line of work in the future.
### 1. Train a style LoRA
```bash
invoke-train -c src/invoke_training/sample_configs/sd_lora_pokemon_1x8gb.yaml
```
### 2. Generate images
Prepare ~100 relevant prompts that will be used to generate training data with the freshly-trained LoRA model. Add the prompts to a `.txt` file - one prompt per line.
Example prompts:
```txt
a cute orange pokemon character with pointy ears
a drawing of a purple fish
a cartoon blob with a smile on its face
a drawing of a snail with big eyes
...
```
```bash
# Convert the LoRA checkpoint of interest to Kohya format.
# You will have to change the path timestamps in this example command.
# TODO(ryand): This manual conversion shouldn't be necessary.
python src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py \
--src-ckpt-dir output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003/ \
--dst-ckpt-file output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003_kohya.safetensors
# Generate 2 pairs of images for each prompt.
invoke-generate-images \
-o output/pokemon_pairs \
-m runwayml/stable-diffusion-v1-5 \
-v fp16 \
-l output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003_kohya.safetensors \
--sd-version SD \
--prompt-file path/to/prompts.txt \
--set-size 2 \
--num-sets 2 \
--height 512 \
--width 512
```
### 3. Annotate the image pair preferences
Launch the gradio UI for selecting image pair preferences.
```bash
# Note: rank_images.py accepts a full training pipeline config, but only uses the dataset configuration.
python src/invoke_training/scripts/_experimental/rank_images.py -c src/invoke_training/sample_configs/_experimental/sd_dpo_lora_refinement_pokemon_1x24gb.yaml
```
After completing the pair annotations, click "Save Metadata" and move the resultant metadata file to your image data directory (e.g. `output/pokemon_pairs/metadata.jsonl`).
### 4. Run Diffusion-DPO
```bash
invoke-train -c src/invoke_training/sample_configs/_experimental/sd_dpo_lora_refinement_pokemon_1x24gb.yaml
```
================================================
FILE: docs/guides/stable_diffusion/gnome_lora_masks_sdxl.md
================================================
# LoRA with Masks - SDXL
This tutorial explains how to prepare masks for an image dataset and then use that dataset to train an SDXL LoRA model.
Masks can be used to weight regions of images in a dataset to control how much they contribute to the training process. In this tutorial we will use masks to train on a small dataset of images of Bruce the Gnome (4 images). With such a small dataset, there is a high risk of overfitting to the background elements from the images. We will use masks to avoid this problem and focus only on the object of interest.
## 1 - Dataset Preparation
For this tutorial, we'll use a dataset consisting of 4 images of Bruce the Gnome:
| | |
| - | - |
|  |  |
|  |  |
This sample dataset is included in the invoke-training repo under [sample_data/bruce_the_gnome](https://github.com/invoke-ai/invoke-training/tree/main/sample_data/bruce_the_gnome).
## 2 - Generate Masks
Use the `generate_masks_for_jsonl_dataset.py` script to generate masks for your dataset based on a single prompt. In this case we are using the prompt `"a stuffed gnome"`:
```bash
python src/invoke_training/scripts/_experimental/masks/generate_masks_for_jsonl_dataset.py \
--in-jsonl sample_data/bruce_the_gnome/data.jsonl \
--out-jsonl sample_data/bruce_the_gnome/data_masks.jsonl \
--prompt "a stuffed gnome"
```
The mask generation script will produce the following outputs:
- A directory of generated masks: `sample_data/bruce_the_gnome/masks/`
- A new `.jsonl` file that references the mask images: `sample_data/bruce_the_gnome/data_masks.jsonl`
## 3 - Review the Generated Masks
Review the generated masks to make sure that the target regions were masked. You may need to adjust the prompt and re-generate the masks to achieve the desired result. Alternatively, you can edit the masks manually. The masks are simply single-channel grayscale images (0=background, 255=foreground).
Here are some examples of the masks that we just generated:
| | |
| - | - |
|  |  |
|  |  |
## 4 - Configuration
Below is the training configuration that we'll use for this tutorial.
Raw config file: [src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml).
```yaml title="sdxl_lora_masks_gnome_1x24gb.yaml"
--8<-- "src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml"
```
Full documentation of all of the configuration options is here: [LoRA SDXL Config](../../reference/config/pipelines/sdxl_lora.md)
There are a few things to note about this training config:
- We set `use_masks: True` in order to use the masks that we generated. This configuration is only compatible with datasets that have mask data.
- The `learning_rate`, `max_train_steps`, `save_every_n_steps`, and `validate_every_n_steps` are all _lower_ than typical for an SDXL LoRA training pipeline. The combination of masking with the small dataset size causes training to progress very quickly. These configuration fields were all adjusted accordingly to avoid overfitting.
## 5 - Start Training
Launch the training run.
```bash
# From inside the invoke-training/ source directory:
invoke-train -c src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml
```
Training takes ~30 mins on an NVIDIA RTX 4090.
## 6 - Monitor
In a new terminal, launch Tensorboard to monitor the training run:
```bash
tensorboard --logdir output/
```
Access Tensorboard at [localhost:6006](http://localhost:6006) in your browser.
Sample images will be logged to Tensorboard so that you can see how the model is evolving.
Once training is complete, select the model checkpoint that produces the best visual results. For this tutorial, we'll use the checkpoint from step 300:

*Screenshot of the Tensorboard UI showing the validation images for step 300. The validation prompt was: "A stuffed gnome at the beach with a pina colada in its hand.".*
## 7 - Import into InvokeAI
If you haven't already, set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.
Import your trained LoRA model from the 'Models' tab.
Congratulations, you can now use your new Bruce-the-Gnome model! 🎉
================================================
FILE: docs/guides/stable_diffusion/robocats_finetune_sdxl.md
================================================
# Finetune - SDXL
This tutorial explains how to do a full finetune training run on a [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) base model.
## 0 - Prerequisites
Full model finetuning is more compute-intensive than parameter-efficient finetuning alternatives (e.g. LoRA or Textual Inversion). This tutorial requires a minimum of 24GB of GPU VRAM.
## 1 - Dataset Preparation
For this tutorial, we will use a dataset consisting of 14 images of robocats. The images were auto-captioned. Here are some sample images from the dataset, including their captions:
| | |
| - | - |
|  |  |
| *A white robot with blue eyes and a yellow nose sits on a rock, gazing at the camera, with a pink tree and a white cat in the background.* | *A white cat with green eyes and a blue collar sits on a moss-covered rock in a forest, gazing directly at the camera.* |
## 2 - Configuration
Below is the training configuration that we'll use for this tutorial.
Raw config file: [src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml).
```yaml title="sdxl_finetune_robocats_1x24gb.yaml"
--8<-- "src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml"
```
Full documentation of all of the configuration options is here: [Finetune SDXL Config](../../reference/config/pipelines/sdxl_finetune.md)
!!! note "`save_checkpoint_format`"
Note the `save_checkpoint_format` setting, as it is unique to full finetune training. For this tutorial, we have set `save_checkpoint_format: trained_only_diffusers`. This means that only the UNet model will be saved at each checkpoint, and it will be saved in diffusers format. This setting conserves disk space by not redundantly saving the non-trained weights. Before these UNet checkpoints can be used, they must either be merged into a full model, or extracted into a LoRA. Instructions for this follow later in this tutorial. A full explanation of the `save_checkpoint_format` options can be found here: [save_checkpoint_format][invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig.save_checkpoint_format].
## 3 - Start Training
Launch the training run.
```bash
# From inside the invoke-training/ source directory:
invoke-train -c src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml
```
Training takes ~45 mins on an NVIDIA RTX 4090.
## 4 - Monitor
In a new terminal, launch Tensorboard to monitor the training run:
```bash
tensorboard --logdir output/
```
Access Tensorboard at [localhost:6006](http://localhost:6006) in your browser.
Sample images will be logged to Tensorboard so that you can see how the model is evolving.
Once training is complete, select the model checkpoint that produces the best visual results.
## 5 - Prepare the trained model
Since we set `save_checkpoint_format: trained_only_diffusers`, our selected checkpoint only contains the UNet model weights. The checkpoint has the following directory structure:
```bash
output/robocats/sdxl_finetune/1715373799.3558652/checkpoints/checkpoint-epoch_00000500-step_00002000/
└── unet
├── config.json
└── diffusion_pytorch_model.safetensors
```
Before we can use this trained model, we must do one of the following:
- Prepare a full diffusers checkpoint with the new UNet weights.
- Extract the difference between the trained UNet and the original UNet into a LoRA model.
### Prepare a full model
If we want to use our finetuned UNet model, we must first package it into a format supported by applications like InvokeAI.
In this section we will assume that we have a full SDXL base model in diffusers format. It should have a directory structure like the one shown below. We simply need to replace the `unet/` directory with the one from our selected training checkpoint:
```bash
stable-diffusion-xl-base-1.0
├── model_index.json
├── scheduler
│ └── scheduler_config.json
├── text_encoder
│ ├── config.json
│ └── model.fp16.safetensors
├── text_encoder_2
│ ├── config.json
│ └── model.fp16.safetensors
├── tokenizer
│ ├── merges.txt
│ ├── special_tokens_map.json
│ ├── tokenizer_config.json
│ └── vocab.json
├── tokenizer_2
│ ├── merges.txt
│ ├── special_tokens_map.json
│ ├── tokenizer_config.json
│ └── vocab.json
├── unet # <-- Replace this directory with the trained checkpoint.
│ ├── config.json
│ └── diffusion_pytorch_model.fp16.safetensors
├── vae
│ ├── config.json
│ └── diffusion_pytorch_model.fp16.safetensors
└── vae_1_0
└── diffusion_pytorch_model.fp16.safetensors
```
!!! note "diffusers variants (e.g. 'fp16')"
In this example, notice that the `*.safetensors` files contain `.fp16.` in their filenames. Hugging Face refers to this identifier as a "variant". It is used to select between multiple model variants in their model hub.
In this case, we should add the `.fp16.` variant tag to our finetuned UNet for consistency with the rest of the model. Since we set `save_dtype: float16` in our training config, the `fp16` tag accurately represents the precision of our UNet model file.
### Extract a LoRA model
An alternative to using the finetuned UNet model directly is to compare it against the original and extract the difference as a LoRA model. The resultant LoRA has a much smaller file size and can be applied to any base model. But, the LoRA model is a *lossy* representation of the difference, so some quality degradation is expected.
To extract a LoRA model, run the following command:
```bash
python src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py \
--model-type SDXL \
--model-orig path/to/stable-diffusion-xl-base-1.0 \
--model-tuned output/robocats/sdxl_finetune/1715373799.3558652/checkpoints/checkpoint-epoch_00000500-step_00002000 \
--save-to robocats_lora_step_2000.safetensors \
--lora-rank 32
```
## 6 - Import into InvokeAI
If you haven't already, set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.
Import your finetuned diffusers model or your extracted LoRA from the 'Models' tab.
Congratulations, you can now use your new robocat model! 🎉
## 7 - Comparison: Finetune vs. LoRA Extraction
As noted earlier, the LoRA extraction process is lossy for a number of reasons.
Below, we compare images generated with the same seed and prompt for 3 different model configurations.
Prompt: *In robocat style, a robotic lion in the jungle.*
| SDXL Base 1.0 | w/ Finetuned UNet | w/ Extracted LoRA |
| - | - | - |
|  |  | 
================================================
FILE: docs/guides/stable_diffusion/textual_inversion_sdxl.md
================================================
# Textual Inversion - SDXL
This tutorial walks through a [Textual Inversion](https://arxiv.org/abs/2208.01618) training run with a [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) base model.
## 1 - Dataset
For this tutorial, we'll use a dataset consisting of 4 images of Bruce the Gnome:
| | |
| - | - |
|  |  |
|  |  |
This sample dataset is included in the invoke-training repo under [sample_data/bruce_the_gnome](https://github.com/invoke-ai/invoke-training/tree/main/sample_data/bruce_the_gnome).
Here are a few tips for preparing a Textual Inversion dataset:
- Aim for 4 to 50 images of your concept (object / style). The optimal number depends on many factors, and can be much higher than this for some use cases.
- Vary all of the image features that you *don't* want your TI embedding to contain (e.g. background, pose, lighting, etc.).
## 2 - Configuration
Below is the training configuration that we'll use for this tutorial.
Raw config file: [src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml).
Full config reference docs: [Textual Inversion SDXL Config](../../reference/config/pipelines/sdxl_textual_inversion.md)
```yaml title="sdxl_textual_inversion_gnome_1x24gb.yaml"
--8<-- "src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml"
```
## 3 - Start Training
[Install invoke-training](../../get-started/installation.md), if you haven't already.
Launch the Textual Inversion training pipeline:
```bash
# From inside the invoke-training/ source directory:
invoke-train -c src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml
```
Training takes ~40 mins on an NVIDIA RTX 4090.
## 4 - Monitor
In a new terminal, launch Tensorboard to monitor the training run:
```bash
tensorboard --logdir output/
```
Access Tensorboard at [localhost:6006](http://localhost:6006) in your browser.
Sample images will be logged to Tensorboard so that you can see how the Textual Inversion embedding is evolving.
Once training is complete, select the epoch that produces the best visual results.
For this tutorial, we'll choose epoch 500:

*Screenshot of the Tensorboard UI showing the validation images for epoch 500.*
## 5 - Transfer to InvokeAI
If you haven't already, set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.
Copy the selected TI embedding into your `${INVOKEAI_ROOT}/autoimport/embedding/` directory. For example:
```bash
cp output/sdxl_ti_bruce_the_gnome/1702587511.2273068/checkpoint_epoch-00000500.safetensors ${INVOKEAI_ROOT}/autoimport/embedding/bruce_the_gnome.safetensors
```
Note that we renamed the file to `bruce_the_gnome.safetensors`. You can choose any file name, but this will become the token used to reference your embedding. So, in our case, we can refer to our new embedding by including `<bruce_the_gnome>` in our prompts.
Launch InvokeAI and you can now use your new `bruce_the_gnome` TI embedding! 🎉

*Example image generated with the prompt "`a photo of <bruce_the_gnome> at the park`".*
================================================
FILE: docs/index.md
================================================
# invoke-training
A library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion, etc.) that can be used in [InvokeAI](https://github.com/invoke-ai/InvokeAI).
## Documentation
The documentation is organized as follows:
- [Get Started](get-started/installation.md): Install `invoke-training` and run your first training pipeline.
- [Guides](guides/dataset_formats.md): Full tutorials for running popular training pipelines.
- [Config Reference](reference/config/index.md): Reference documentation for all supported training configuration options.
- [Contributing](contributing/development_environment.md): Information for `invoke-training` developers.
================================================
FILE: docs/reference/config/index.md
================================================
# Config Reference
This section contains reference documentation for the `invoke-training` configuration schema (i.e. documentation for all of the supported training options).
This documentation uses Python typing semantics to define the configuration schema. Typically, the configuration for a training run is specified in a YAML file and then parsed against this schema.
================================================
FILE: docs/reference/config/pipelines/sd_lora.md
================================================
# `SdLoraConfig`
<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->
::: invoke_training.pipelines.stable_diffusion.lora.config.SdLoraConfig
options:
members:
- type
<!-- Note that we always hide "model_config", as it should not be set by the user. -->
::: invoke_training.pipelines.stable_diffusion.lora.config.SdLoraConfig
options:
filters:
- "!^model_config"
- "!^type"
================================================
FILE: docs/reference/config/pipelines/sd_textual_inversion.md
================================================
# `SdTextualInversionConfig`
<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->
::: invoke_training.pipelines.stable_diffusion.textual_inversion.config.SdTextualInversionConfig
options:
members:
- type
<!-- Note that we always hide "model_config", as it should not be set by the user. -->
::: invoke_training.pipelines.stable_diffusion.textual_inversion.config.SdTextualInversionConfig
options:
filters:
- "!^model_config"
- "!^type"
================================================
FILE: docs/reference/config/pipelines/sdxl_finetune.md
================================================
# `SdxlFinetuneConfig`
<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->
::: invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig
options:
members:
- type
<!-- Note that we always hide "model_config", as it should not be set by the user. -->
::: invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig
options:
filters:
- "!^model_config"
- "!^type"
================================================
FILE: docs/reference/config/pipelines/sdxl_lora.md
================================================
# `SdxlLoraConfig`
<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->
::: invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig
options:
members:
- type
<!-- Note that we always hide "model_config", as it should not be set by the user. -->
::: invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig
options:
filters:
- "!^model_config"
- "!^type"
================================================
FILE: docs/reference/config/pipelines/sdxl_lora_and_textual_inversion.md
================================================
# `SdxlLoraAndTextualInversionConfig`
<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->
::: invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config.SdxlLoraAndTextualInversionConfig
options:
members:
- type
<!-- Note that we always hide "model_config", as it should not be set by the user. -->
::: invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config.SdxlLoraAndTextualInversionConfig
options:
filters:
- "!^model_config"
- "!^type"
================================================
FILE: docs/reference/config/pipelines/sdxl_textual_inversion.md
================================================
# `SdxlTextualInversionConfig`
Below is a sample yaml config file for Textual Inversion SDXL training ([raw file](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml)). All of the configuration fields are explained in detail on this page.
```yaml title="sdxl_textual_inversion_gnome_1x24gb.yaml"
--8<-- "src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml"
```
<!-- To control the member order, we first list out the members whose order we care about, then we list the rest. -->
::: invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config.SdxlTextualInversionConfig
options:
members:
- type
<!-- Note that we always hide "model_config", as it should not be set by the user. -->
::: invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config.SdxlTextualInversionConfig
options:
filters:
- "!^model_config"
- "!^type"
================================================
FILE: docs/reference/config/shared/data/data_loader_config.md
================================================
::: invoke_training.config.data.data_loader_config
options:
filters:
- "!^model_config"
================================================
FILE: docs/reference/config/shared/data/dataset_config.md
================================================
::: invoke_training.config.data.dataset_config
options:
filters:
- "!^model_config"
================================================
FILE: docs/reference/config/shared/optimizer_config.md
================================================
::: invoke_training.config.optimizer.optimizer_config
options:
filters:
- "!^model_config"
================================================
FILE: docs/templates/python/material/labels.html
================================================
<!--
This file is intentionally empty. It overrides the default contents of
https://github.com/mkdocstrings/python/blob/master/src/mkdocstrings_handlers/python/templates/material/labels.html
to hide labels (class-attribute, instance-attribute, etc.)
-->
================================================
FILE: mkdocs.yml
================================================
site_name: invoke-training
site_url: https://invoke-ai.github.io/invoke-training/
repo_name: invoke-ai/invoke-training
repo_url: https://github.com/invoke-ai/invoke-training
theme:
name: material
features:
- navigation.tabs
- navigation.indexes
- navigation.sections
- content.code.copy
markdown_extensions:
- admonition
- sane_lists
- pymdownx.highlight:
anchor_linenums: true
line_spans: __span
pygments_lang_class: true
- pymdownx.inlinehilite
- pymdownx.snippets
- pymdownx.superfences
nav:
- Welcome: index.md
- Get Started:
- get-started/installation.md
- get-started/quick-start.md
- Guides:
- Dataset Formats: guides/dataset_formats.md
- Model Merging: guides/model_merge.md
- Stable Diffusion Training:
- guides/stable_diffusion/robocats_finetune_sdxl.md
- guides/stable_diffusion/gnome_lora_masks_sdxl.md
- guides/stable_diffusion/textual_inversion_sdxl.md
- guides/stable_diffusion/dpo_lora_sd.md
- YAML Config Reference:
- reference/config/index.md
- pipelines:
- SD LoRA Config: reference/config/pipelines/sd_lora.md
- SD Textual Inversion Config: reference/config/pipelines/sd_textual_inversion.md
- SDXL LoRA Config: reference/config/pipelines/sdxl_lora.md
- SDXL Textual Inversion Config: reference/config/pipelines/sdxl_textual_inversion.md
- SDXL LoRA and Textual Inversion Config: reference/config/pipelines/sdxl_lora_and_textual_inversion.md
- SDXL Finetune Config: reference/config/pipelines/sdxl_finetune.md
- shared:
- data_loader_config: reference/config/shared/data/data_loader_config.md
- dataset_config: reference/config/shared/data/dataset_config.md
- optimizer_config: reference/config/shared/optimizer_config.md
- Contributing:
- contributing/development_environment.md
- contributing/directory_structure.md
- contributing/tests.md
- contributing/documentation.md
plugins:
- search
- mkdocstrings:
default_handler: python
custom_templates: docs/templates
handlers:
python:
options:
show_root_heading: false
show_root_toc_entry: false
show_bases: false
show_source: false
show_if_no_docstring: true
inherited_members: true
annotations_path: brief
separate_signature: true
show_signature_annotations: true
members_order: source
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=65.5", "pip>=22.3"]
build-backend = "setuptools.build_meta"
[project]
name = "invoke-training"
version = "0.0.1"
authors = [{ name = "The Invoke AI Team", email = "ryan@invoke.ai" }]
description = "A library for Stable Diffusion model training."
readme = "README.md"
requires-python = ">=3.10"
license = { text = "Apache-2.0" }
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
]
dependencies = [
"accelerate",
"datasets~=2.14.3",
"diffusers[torch]",
"einops",
"fastapi",
"gradio",
"invokeai>=5.10.0a1",
"numpy<2.0.0",
"omegaconf",
"peft~=0.11.1",
"pillow",
"prodigyopt",
"pydantic",
"pyyaml",
"safetensors",
"tensorboard",
"torch",
"torchvision",
"tqdm",
"transformers",
"uvicorn[standard]",
]
[project.optional-dependencies]
"xformers" = ["xformers>=0.0.28.post1; sys_platform!='darwin'"]
"bitsandbytes" = ["bitsandbytes>=0.43.1; sys_platform!='darwin'"]
"test" = [
"mkdocs",
"mkdocs-material",
"mkdocstrings[python]",
"pre-commit~=3.3.3",
"pytest~=7.4.0",
"ruff~=0.11.2",
"ruff-lsp",
]
[project.scripts]
"invoke-train" = "invoke_training.scripts.invoke_train:main"
"invoke-train-ui" = "invoke_training.scripts.invoke_train_ui:main"
"invoke-generate-images" = "invoke_training.scripts.invoke_generate_images:main"
"invoke-visualize-data-loading" = "invoke_training.scripts.invoke_visualize_data_loading:main"
[project.urls]
"Homepage" = "https://github.com/invoke-ai/invoke-training"
"Discord" = "https://discord.gg/ZmtBAhwWhy"
[tool.setuptools.package-data]
"invoke_training.assets" = ["*.png"]
"invoke_training.sample_configs" = ["**/*.yaml"]
"invoke_training.ui" = ["*.html"]
[tool.ruff]
src = ["src"]
lint.select = ["E", "F", "W", "C9", "N8", "I"]
target-version = "py39"
line-length = 120
[tool.pytest.ini_options]
addopts = "--strict-markers"
markers = [
"cuda: marks tests that require a CUDA GPU",
"loads_model: marks tests that require a model (or data) from the HF hub",
]
================================================
FILE: sample_data/bruce_the_gnome/data.jsonl
================================================
{"image": "001.png", "text": "A stuffed gnome sits on a wooden floor, facing right with a gray couch in the background."}
{"image": "002.png", "text": "A stuffed gnome stands on a black tiled floor, with a silver refrigerator and white wall in the background."}
{"image": "004.png", "text": "A stuffed gnome sits on a white marble floor, photorealistic."}
{"image": "003.png", "text": "A stuffed gnome sits on a gray tiled floor, facing the camera."}
================================================
FILE: src/invoke_training/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/accelerator/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/accelerator/accelerator_utils.py
================================================
import logging
import os
from typing import Literal
import datasets
import diffusers
import torch
import transformers
from accelerate import Accelerator
from accelerate.logging import MultiProcessAdapter, get_logger
from accelerate.utils import ProjectConfiguration
def initialize_accelerator(
    out_dir: str, gradient_accumulation_steps: int, mixed_precision: str, log_with: str
) -> Accelerator:
    """Build a Hugging Face `accelerate.Accelerator` for a training run.

    Args:
        out_dir (str): The output directory where results will be written. It is used as the accelerate project
            directory, and `{out_dir}/logs` is used as the accelerate logging directory.
        gradient_accumulation_steps (int): Forwarded to `accelerate.Accelerator(...)`.
        mixed_precision (str): Forwarded to `accelerate.Accelerator(...)`.
        log_with (str): Forwarded to `accelerate.Accelerator(...)`.

    Returns:
        Accelerator: The configured accelerator.
    """
    logging_dir = os.path.join(out_dir, "logs")
    project_config = ProjectConfiguration(project_dir=out_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
        log_with=log_with,
        project_config=project_config,
    )
    return accelerator
def initialize_logging(logger_name: str, accelerator: Accelerator) -> MultiProcessAdapter:
    """Set up process-aware logging and return an accelerate logger.

    The root logger is configured with `logging.basicConfig(...)` at INFO level. On the local main process, the
    `datasets` and `transformers` libraries are set to WARNING verbosity and `diffusers` to INFO verbosity. On every
    other process, all three libraries only log errors.

    Args:
        logger_name (str): The name passed to `accelerate.logging.get_logger(...)`.
        accelerator (Accelerator): Used to determine whether this is the local main process.

    Returns:
        MultiProcessAdapter: An accelerate logger with multi-process logging support, named `logger_name`.
    """
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
    )

    if not accelerator.is_local_main_process:
        # Only log errors from non-main processes.
        quiet_libraries = (
            datasets.utils.logging,
            transformers.utils.logging,
            diffusers.utils.logging,
        )
        for library_logging in quiet_libraries:
            library_logging.set_verbosity_error()
    else:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()

    return get_logger(logger_name)
def get_mixed_precision_dtype(accelerator: Accelerator):
    """Infer the weight `torch.dtype` from an Accelerator's `mixed_precision` setting.

    Mapping:
        None / "no" -> torch.float32
        "fp16"      -> torch.float16
        "bf16"      -> torch.bfloat16

    Args:
        accelerator (Accelerator): The Hugging Face Accelerator.

    Raises:
        NotImplementedError: If `accelerator.mixed_precision` is any other value.

    Returns:
        torch.dtype: The weight dtype.
    """
    mode = accelerator.mixed_precision
    if mode is None or mode == "no":
        return torch.float32
    if mode == "fp16":
        return torch.float16
    if mode == "bf16":
        return torch.bfloat16
    raise NotImplementedError(f"mixed_precision mode '{accelerator.mixed_precision}' is not yet supported.")
def get_dtype_from_str(dtype_str: Literal["float16", "bfloat16", "float32"]) -> torch.dtype:
    """Map a dtype name to the corresponding `torch.dtype`.

    Args:
        dtype_str: One of "float16", "bfloat16", or "float32".

    Raises:
        ValueError: If `dtype_str` is not one of the supported names.

    Returns:
        torch.dtype: The matching torch dtype.
    """
    known_dtypes = (
        ("float16", torch.float16),
        ("bfloat16", torch.bfloat16),
        ("float32", torch.float32),
    )
    for name, dtype in known_dtypes:
        if dtype_str == name:
            return dtype
    raise ValueError(f"Unsupported dtype: {dtype_str}")
================================================
FILE: src/invoke_training/_shared/checkpoints/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/checkpoints/checkpoint_tracker.py
================================================
import os
import shutil
import typing
class CheckpointTracker:
    """A utility class for managing checkpoint paths.

    Manages checkpoint paths of the following forms:
    - Checkpoint directories: `{base_dir}/{prefix}-epoch_{num_epochs}-step_{num_steps}`
    - Checkpoint files: `{base_dir}/{prefix}-epoch_{num_epochs}-step_{num_steps}{extension}`

    Leading/trailing whitespace in `prefix` is stripped, both when building paths and when matching existing
    checkpoints during pruning.
    """

    def __init__(
        self,
        base_dir: str,
        prefix: str,
        extension: typing.Optional[str] = None,
        max_checkpoints: typing.Optional[int] = None,
        index_padding: int = 8,
    ):
        """Initialize a CheckpointTracker.

        Args:
            base_dir (str): The base checkpoint directory.
            prefix (str): A prefix applied to every checkpoint.
            extension (str, optional): If set, this is the file extension that will be applied to all checkpoints
                (usually one of ".pt", ".ckpt", or ".safetensors"). If None, then it will be assumed that we are
                managing checkpoint directories rather than files.
            max_checkpoints (typing.Optional[int], optional): The maximum number of checkpoints that should exist in
                base_dir. If None, `prune()` is a no-op.
            index_padding (int, optional): The length of the zero-padded epoch/step counts in the generated checkpoint
                names. E.g. index_padding=8 would produce checkpoint paths like
                "base_dir/prefix-epoch_00000001-step_00000001.ckpt".

        Raises:
            ValueError: If extension is provided, but it does not start with a '.'.
        """
        if extension is not None and not extension.startswith("."):
            raise ValueError(f"extension='{extension}' must start with a '.'.")

        self._base_dir = base_dir
        self._prefix = prefix
        self._extension = extension
        self._max_checkpoints = max_checkpoints
        self._index_padding = index_padding

    def _parse_step(self, name: str) -> typing.Optional[int]:
        """Return the step count encoded in the entry `name`, or None if `name` is not a checkpoint that matches this
        tracker's naming scheme (wrong prefix, wrong extension, or no parseable step count).
        """
        # Match exactly what get_path() produces; get_path() strips whitespace from the prefix.
        if not name.startswith(f"{self._prefix.strip()}-epoch_"):
            return None
        if self._extension is None:
            # Directory checkpoints carry no suffix.
            stem = name
        elif name.endswith(self._extension):
            stem = name[: -len(self._extension)]
        else:
            return None
        _, separator, step_str = stem.rpartition("-step_")
        if not separator:
            return None
        try:
            return int(step_str)
        except ValueError:
            return None

    def prune(self, buffer_num: int = 1) -> int:
        """Delete checkpoint files and directories so that there are at most `max_checkpoints - buffer_num` checkpoints
        remaining. The checkpoints with the lowest step counts will be deleted.

        Only entries in `base_dir` that match this tracker's naming scheme are considered; unrelated files or
        directories that merely share the prefix are left untouched.

        Args:
            buffer_num (int, optional): The number below `max_checkpoints` to 'free-up'.

        Returns:
            int: The number of checkpoints actually deleted.
        """
        if self._max_checkpoints is None:
            return 0
        if not os.path.isdir(self._base_dir):
            # Nothing to prune if the checkpoint directory does not exist.
            return 0

        checkpoints: list[tuple[int, str]] = []
        for name in os.listdir(self._base_dir):
            step = self._parse_step(name)
            if step is not None:
                checkpoints.append((step, name))
        # Stable sort by step only, so the oldest checkpoints come first.
        checkpoints.sort(key=lambda item: item[0])

        # Clamp so that a buffer larger than max_checkpoints removes everything but never over-reports.
        num_to_keep = max(0, self._max_checkpoints - buffer_num)
        num_to_remove = max(0, len(checkpoints) - num_to_keep)
        for _, name in checkpoints[:num_to_remove]:
            checkpoint_path = os.path.join(self._base_dir, name)
            if os.path.isfile(checkpoint_path):
                # Delete checkpoint file.
                os.remove(checkpoint_path)
            else:
                # Delete checkpoint directory.
                shutil.rmtree(checkpoint_path)

        return num_to_remove

    def get_path(self, epoch: int, step: int) -> str:
        """Get the checkpoint path for the given epoch and step counts.

        Args:
            epoch (int): The number of completed epochs.
            step (int): The number of completed training steps.

        Returns:
            str: The checkpoint path.
        """
        suffix = self._extension or ""
        return os.path.join(
            self._base_dir,
            f"{self._prefix.strip()}-epoch_{epoch:0>{self._index_padding}}-step_{step:0>{self._index_padding}}{suffix}",
        )
================================================
FILE: src/invoke_training/_shared/checkpoints/lora_checkpoint_utils.py
================================================
from pathlib import Path
import peft
import torch
def save_multi_model_peft_checkpoint(checkpoint_dir: Path | str, models: dict[str, peft.PeftModel]):
    """Save several PeftModels into one checkpoint directory, one subdirectory per model.

    Each model is written with `save_pretrained(...)` to `{checkpoint_dir}/{key}`, where `key` is its key in `models`.
    Use `load_multi_model_peft_checkpoint(...)` to load the resulting checkpoint.

    Args:
        checkpoint_dir (Path | str): The checkpoint root directory.
        models (dict[str, peft.PeftModel]): Models to save, keyed by subdirectory name.

    Raises:
        AssertionError: If any value in `models` is not a `peft.PeftModel`.
    """
    root = Path(checkpoint_dir)
    for subdir_name, model in models.items():
        assert isinstance(model, peft.PeftModel)

        # HACK(ryand): PeftModel.save_pretrained(...) expects the config to have a "_name_or_path" entry. For now, we
        # set this to None here. This should be fixed upstream in PEFT.
        model_config = getattr(model, "config", None)
        needs_name_or_path = isinstance(model_config, dict) and "_name_or_path" not in model_config
        if needs_name_or_path:
            model_config["_name_or_path"] = None

        model.save_pretrained(str(root / subdir_name))
def load_multi_model_peft_checkpoint(
checkpoint_dir: Path | str,
models: dict[str, torch.nn.Module],
is_trainable: bool = False,
raise_if_subdir_missing: bool = True,
) -> dict[str, torch.nn.Module]:
"""Load a multi-model PEFT checkpoint that was saved with `save_multi_model_peft_checkpoint(...)`."""
checkpoint_dir = Path(checkpoint_dir)
assert checkpoint_dir.exists()
out_models = {}
for model_key, model in models.items():
dir_path: Path = checkpoint_dir / model_key
if dir_path.exists():
out_models[model_key] = peft.PeftModel.from_pretrained(model, dir_path, is_trainable=is_trainable)
else:
if raise_if_subdir_missing:
raise ValueError(f"'{dir_path}' does not exist.")
else:
# Pass through the model unchanged.
out_models[model_key] = model
return out_models
# This implementation is based on
# https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/lora_dreambooth/convert_peft_sd_lora_to_kohya_ss.py#L20
def _convert_peft_state_dict_to_kohya_state_dict(
    lora_config: peft.LoraConfig,
    peft_state_dict: dict[str, torch.Tensor],
    prefix: str,
    dtype: torch.dtype,
) -> dict[str, torch.Tensor]:
    """Rename a PEFT LoRA state_dict to kohya_ss key conventions and cast it to `dtype`.

    Key rewriting, applied in order: "base_model.model" -> `prefix`, "lora_A" -> "lora_down", "lora_B" -> "lora_up",
    then every "." except the last two is replaced by "_". For each "lora_down" entry, an extra
    "<module>.alpha" entry holding `lora_config.lora_alpha` (cast to `dtype`) is added.

    Args:
        lora_config (peft.LoraConfig): Only `lora_alpha` is read.
        peft_state_dict (dict[str, torch.Tensor]): The PEFT state_dict.
        prefix (str): The kohya_ss module prefix that replaces "base_model.model".
        dtype (torch.dtype): The dtype of every output tensor.

    Returns:
        dict[str, torch.Tensor]: The kohya_ss-style state_dict.
    """
    key_renames = (
        ("base_model.model", prefix),
        ("lora_A", "lora_down"),
        ("lora_B", "lora_up"),
    )

    converted: dict[str, torch.Tensor] = {}
    for peft_key, tensor in peft_state_dict.items():
        kohya_key = peft_key
        for old, new in key_renames:
            kohya_key = kohya_key.replace(old, new)
        # Collapse the module path into one underscore-joined name, keeping only the final two dots
        # (e.g. "<module>.lora_down.weight").
        dots_to_replace = kohya_key.count(".") - 2
        kohya_key = kohya_key.replace(".", "_", dots_to_replace)

        converted[kohya_key] = tensor.to(dtype)

        # Set alpha parameter
        if "lora_down" in kohya_key:
            module_name = kohya_key.split(".")[0]
            converted[f"{module_name}.alpha"] = torch.tensor(lora_config.lora_alpha).to(dtype)

    return converted
def _convert_peft_models_to_kohya_state_dict(
    kohya_prefixes: list[str], models: list[peft.PeftModel]
) -> dict[str, torch.Tensor]:
    """Convert the "default" LoRA adapter of each PeftModel to kohya_ss format and merge the results.

    `kohya_prefixes[i]` is used as the kohya_ss key prefix for `models[i]`. The two lists are zipped with
    `strict=True`, so a length mismatch raises ValueError. All output tensors are float32.

    Args:
        kohya_prefixes (list[str]): One kohya_ss prefix per model.
        models (list[peft.PeftModel]): The PEFT models to convert.

    Raises:
        AssertionError: If a model's "default" adapter config is not a `peft.LoraConfig`.

    Returns:
        dict[str, torch.Tensor]: The merged kohya_ss-style state_dict.
    """
    adapter_name = "default"
    merged_state_dict: dict[str, torch.Tensor] = {}

    for kohya_prefix, peft_model in zip(kohya_prefixes, models, strict=True):
        adapter_config = peft_model.peft_config[adapter_name]
        assert isinstance(adapter_config, peft.LoraConfig)

        adapter_state_dict = peft.get_peft_model_state_dict(peft_model, adapter_name=adapter_name)
        converted = _convert_peft_state_dict_to_kohya_state_dict(
            lora_config=adapter_config,
            peft_state_dict=adapter_state_dict,
            prefix=kohya_prefix,
            dtype=torch.float32,
        )
        merged_state_dict.update(converted)

    return merged_state_dict
================================================
FILE: src/invoke_training/_shared/checkpoints/serialization.py
================================================
import typing
from pathlib import Path
import safetensors.torch
import torch
def save_state_dict(state_dict: typing.Dict[str, torch.Tensor], out_file: typing.Union[Path, str]):
    """Write a state_dict to disk, choosing the format from the file extension.

    Supported extensions:
    - ".ckpt" -> torch
    - ".pt" -> torch
    - ".safetensors" -> safetensors

    Args:
        state_dict (typing.Dict[str, torch.Tensor]): The state_dict to save.
        out_file (Path | str): The output file to save to.

    Raises:
        ValueError: If the `out_file` has an unsupported file extension.
    """
    out_file = Path(out_file)
    suffix = out_file.suffix

    if suffix in (".ckpt", ".pt"):
        torch.save(state_dict, out_file)
        return
    if suffix == ".safetensors":
        safetensors.torch.save_file(state_dict, out_file)
        return
    raise ValueError(f"Unsupported file extension: '{out_file.suffix}'.")
def load_state_dict(in_file: typing.Union[Path, str]) -> typing.Dict[str, torch.Tensor]:
    """Read a state_dict from disk, choosing the format from the file extension.

    Supported extensions:
    - ".ckpt" -> torch
    - ".pt" -> torch
    - ".safetensors" -> safetensors

    Args:
        in_file (Path | str): The input file to load from.

    Raises:
        ValueError: If the `in_file` has an unsupported file extension.

    Returns:
        typing.Dict[str, torch.Tensor]: The loaded state_dict.
    """
    in_file = Path(in_file)
    suffix = in_file.suffix

    if suffix == ".safetensors":
        return safetensors.torch.load_file(in_file)
    if suffix in (".ckpt", ".pt"):
        # NOTE(review): torch.load unpickles the file. Depending on the installed torch version's `weights_only`
        # default, this can execute arbitrary code, so only load trusted files.
        return torch.load(in_file)
    raise ValueError(f"Unsupported file extension: '{in_file.suffix}'.")
================================================
FILE: src/invoke_training/_shared/data/ARCHITECTURE.md
================================================
# Dataset Architecture
Dataset handling is split into 3 layers of abstraction: Datasets, Transforms, and DataLoaders. Each is explained in more detail below.
## Datasets
Datasets implement the [torch.utils.data.Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files) interface.
Most dataset classes act as an abstraction over a specific dataset format.
## Transforms
Transforms are functions applied to data loaded by Datasets. For example, the `SDImageTransform` implements image augmentations for Stable Diffusion training.
Transforms are kept separate from the underlying datasets for several reasons:
- It is easier to write tests for isolated transforms.
- Modular transforms can often be re-used for multiple base datasets.
- Modular transforms make it easy to customize datasets for different situations. For example, you may want to wrap a dataset with one set of transforms initially to populate a cache, and then with a different set of transforms to read from the cache.
Transforms are applied to a dataset via the `TransformDataset` class.
## DataLoaders
The dataset classes (with composed transforms) are wrapped in a `torch.utils.data.DataLoader` that handles batch collation, multi-processing, etc.
================================================
FILE: src/invoke_training/_shared/data/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/data/data_loaders/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/data/data_loaders/dreambooth_sd_dataloader.py
================================================
import typing
from torch.utils.data import ConcatDataset, DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
build_aspect_ratio_bucket_manager,
sd_image_caption_collate_fn,
)
from invoke_training._shared.data.datasets.image_dir_dataset import ImageDirDataset
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler
from invoke_training._shared.data.samplers.batch_offset_sampler import BatchOffsetSampler
from invoke_training._shared.data.samplers.concat_sampler import ConcatSampler
from invoke_training._shared.data.samplers.interleaved_sampler import InterleavedSampler
from invoke_training._shared.data.samplers.offset_sampler import OffsetSampler
from invoke_training._shared.data.transforms.constant_field_transform import ConstantFieldTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig
def build_dreambooth_sd_dataloader(
    config: DreamboothSDDataLoaderConfig,
    batch_size: int,
    text_encoder_output_cache_dir: typing.Optional[str] = None,
    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,
    vae_output_cache_dir: typing.Optional[str] = None,
    shuffle: bool = True,
    sequential_batching: bool = False,
) -> DataLoader:
    """Construct a DataLoader for a DreamBooth dataset for Stable Diffusion.

    Instance images are captioned with `config.instance_caption` and given a loss weight of 1.0. If
    `config.class_dataset` is set, class images are captioned with `config.class_caption` and given a loss weight of
    `config.class_data_loss_weight`. Both datasets are concatenated (instance first), each gets its own sampler, and
    the samplers are combined either sequentially or interleaved.

    Args:
        config (DreamboothSDDataLoaderConfig): The data loader config.
        batch_size (int): The batch size. Without aspect ratio buckets it is passed to the DataLoader; with buckets it
            is passed to the bucket batch samplers.
        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should be
            loaded from.
        text_encoder_cache_field_to_output_field (dict[str, str], optional): Mapping from text encoder cache field names
            to output example field names. Must be set if `text_encoder_output_cache_dir` is set.
        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If
            set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM.
        shuffle (bool, optional): Whether to shuffle the dataset order.
        sequential_batching (bool, optional): If True, the internal dataset will be processed sequentially rather than
            interleaving class and instance examples. This is intended to be used when processing the entire dataset for
            caching purposes. Defaults to False.

    Returns:
        DataLoader
    """
    # Prepare instance dataset.
    # NOTE(review): the "instance_"/"class_" id prefixes presumably keep the "id" field (used as the cache key below)
    # unique across the two datasets.
    base_instance_dataset = ImageDirDataset(
        config.instance_dataset.dataset_dir,
        id_prefix="instance_",
        keep_in_memory=config.instance_dataset.keep_in_memory,
    )
    instance_dataset = TransformDataset(
        base_instance_dataset,
        [
            ConstantFieldTransform("caption", config.instance_caption),
            ConstantFieldTransform("loss_weight", 1.0),
        ],
    )
    datasets = [instance_dataset]

    # Prepare class dataset.
    base_class_dataset = None
    class_dataset = None
    if config.class_dataset is not None:
        base_class_dataset = ImageDirDataset(
            config.class_dataset.dataset_dir, id_prefix="class_", keep_in_memory=config.class_dataset.keep_in_memory
        )
        class_dataset = TransformDataset(
            base_class_dataset,
            [
                ConstantFieldTransform("caption", config.class_caption),
                ConstantFieldTransform("loss_weight", config.class_data_loss_weight),
            ],
        )
        datasets.append(class_dataset)

    # Merge instance dataset and class dataset.
    # Instance examples occupy the first len(base_instance_dataset) indices; class examples follow.
    merged_dataset = ConcatDataset(datasets)

    # Initialize either the fixed target resolution or aspect ratio buckets.
    target_resolution = None
    aspect_ratio_bucket_manager = None
    instance_sampler = None
    class_sampler = None
    if config.aspect_ratio_buckets is None:
        target_resolution = config.resolution
        # TODO(ryand): Provide a seeded generator.
        instance_sampler = RandomSampler(instance_dataset) if shuffle else SequentialSampler(instance_dataset)
        if base_class_dataset is not None:
            class_sampler = RandomSampler(class_dataset) if shuffle else SequentialSampler(class_dataset)
            # Shift class indices past the instance examples so they address the class part of merged_dataset.
            class_sampler = OffsetSampler(class_sampler, offset=len(base_instance_dataset))
    else:
        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)
        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.
        instance_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
            bucket_manager=aspect_ratio_bucket_manager,
            image_sizes=base_instance_dataset.get_image_dimensions(),
            batch_size=batch_size,
            shuffle=shuffle,
            seed=0,
        )
        if base_class_dataset is not None:
            class_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
                bucket_manager=aspect_ratio_bucket_manager,
                image_sizes=base_class_dataset.get_image_dimensions(),
                batch_size=batch_size,
                shuffle=shuffle,
                seed=0,
            )
            # These samplers produce batches of indices (they are passed as `batch_sampler` below), so the
            # batch-aware offset wrapper is used here.
            class_sampler = BatchOffsetSampler(class_sampler, offset=len(base_instance_dataset))

    # Add transforms to the merged dataset.
    all_transforms = []
    if vae_output_cache_dir is None:
        # target_resolution is None when aspect ratio buckets are used; the bucket manager is passed instead.
        all_transforms.append(
            SDImageTransform(
                image_field_names=["image"],
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )
    else:
        vae_cache = TensorDiskCache(vae_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache,
                cache_key_field="id",
                cache_field_to_output_field={
                    "vae_output": "vae_output",
                    "original_size_hw": "original_size_hw",
                    "crop_top_left_yx": "crop_top_left_yx",
                },
            )
        )
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))

    if text_encoder_output_cache_dir is not None:
        assert text_encoder_cache_field_to_output_field is not None
        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=text_encoder_cache,
                cache_key_field="id",
                cache_field_to_output_field=text_encoder_cache_field_to_output_field,
            )
        )

    merged_dataset = TransformDataset(merged_dataset, all_transforms)

    # Choose between sequential vs. interleaved merging of the instance and class samplers.
    # Sequential sampling is typically used to populate a cache, because it guarantees that all examples will be
    # included in an epoch.
    samplers = [instance_sampler]
    if class_sampler is not None:
        samplers.append(class_sampler)

    if sequential_batching:
        sampler = ConcatSampler(samplers)
    else:
        sampler = InterleavedSampler(samplers)

    if config.aspect_ratio_buckets is None:
        return DataLoader(
            merged_dataset,
            sampler=sampler,
            collate_fn=sd_image_caption_collate_fn,
            batch_size=batch_size,
            num_workers=config.dataloader_num_workers,
        )
    else:
        # If config.aspect_ratio_buckets is not None, then we are using a batch sampler.
        return DataLoader(
            merged_dataset,
            batch_sampler=sampler,
            collate_fn=sd_image_caption_collate_fn,
            num_workers=config.dataloader_num_workers,
        )
================================================
FILE: src/invoke_training/_shared/data/data_loaders/image_caption_flux_dataloader.py
================================================
import typing
from torch.utils.data import DataLoader
from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
build_aspect_ratio_bucket_manager,
)
from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
sd_image_caption_collate_fn as flux_image_caption_collate_fn,
)
from invoke_training._shared.data.datasets.build_dataset import (
build_hf_hub_image_caption_dataset,
build_image_caption_dir_dataset,
build_image_caption_jsonl_dataset,
)
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import (
AspectRatioBucketBatchSampler,
)
from invoke_training._shared.data.transforms.caption_prefix_transform import CaptionPrefixTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.flux_image_transform import FluxImageTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training.config.data.data_loader_config import ImageCaptionFluxDataLoaderConfig
from invoke_training.config.data.dataset_config import (
HFHubImageCaptionDatasetConfig,
ImageCaptionDirDatasetConfig,
ImageCaptionJsonlDatasetConfig,
)
def build_image_caption_flux_dataloader(  # noqa: C901
    config: ImageCaptionFluxDataLoaderConfig,
    batch_size: int,
    use_masks: bool = False,
    text_encoder_output_cache_dir: typing.Optional[str] = None,
    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,
    vae_output_cache_dir: typing.Optional[str] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Construct a DataLoader for an image-caption dataset for Flux.1-dev.

    Args:
        config (ImageCaptionFluxDataLoaderConfig): The dataset config.
        batch_size (int): The DataLoader batch size.
        use_masks (bool, optional): If True, the "mask" field is kept. It is transformed alongside "image" when no VAE
            cache is used, or loaded from the VAE cache otherwise. If False, the "mask" field is dropped.
        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should be
            loaded from.
        text_encoder_cache_field_to_output_field (dict[str, str], optional): Mapping from text encoder cache field names
            to output example field names. Must be set if `text_encoder_output_cache_dir` is set.
        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If
            set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM.
        shuffle (bool, optional): Whether to shuffle the dataset order.

    Raises:
        ValueError: If `config.dataset` is not a supported dataset config type.

    Returns:
        DataLoader
    """
    # Select the base dataset builder based on the dataset config type.
    if isinstance(config.dataset, HFHubImageCaptionDatasetConfig):
        base_dataset = build_hf_hub_image_caption_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionJsonlDatasetConfig):
        base_dataset = build_image_caption_jsonl_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionDirDatasetConfig):
        base_dataset = build_image_caption_dir_dataset(config.dataset)
    else:
        raise ValueError(f"Unexpected dataset config type: '{type(config.dataset)}'.")

    # Initialize either the fixed target resolution or aspect ratio buckets.
    if config.aspect_ratio_buckets is None:
        aspect_ratio_bucket_manager = None
        batch_sampler = None
    else:
        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)
        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.
        batch_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
            bucket_manager=aspect_ratio_bucket_manager,
            image_sizes=base_dataset.get_image_dimensions(),
            batch_size=batch_size,
            shuffle=shuffle,
            seed=0,
        )

    all_transforms = []
    if config.caption_prefix is not None:
        # The prefix is joined to the caption with a single space.
        all_transforms.append(CaptionPrefixTransform(caption_field_name="caption", prefix=config.caption_prefix + " "))

    if vae_output_cache_dir is None:
        image_field_names = ["image"]
        if use_masks:
            image_field_names.append("mask")
        else:
            all_transforms.append(DropFieldTransform("mask"))

        # NOTE(review): config.resolution is passed even when a bucket manager is set; presumably the bucket manager
        # takes precedence inside FluxImageTransform - confirm.
        all_transforms.append(
            FluxImageTransform(
                image_field_names=image_field_names,
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=config.resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )
    else:
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))
        all_transforms.append(DropFieldTransform("mask"))

        vae_cache = TensorDiskCache(vae_output_cache_dir)

        cache_field_to_output_field = {
            "vae_output": "vae_output",
            "original_size_hw": "original_size_hw",
            "crop_top_left_yx": "crop_top_left_yx",
        }
        # The raw mask was dropped above, so when masks are requested it is restored from the cache.
        if use_masks:
            cache_field_to_output_field["mask"] = "mask"
        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache,
                cache_key_field="id",
                cache_field_to_output_field=cache_field_to_output_field,
            )
        )

    if text_encoder_output_cache_dir is not None:
        assert text_encoder_cache_field_to_output_field is not None
        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=text_encoder_cache,
                cache_key_field="id",
                cache_field_to_output_field=text_encoder_cache_field_to_output_field,
            )
        )

    dataset = TransformDataset(base_dataset, all_transforms)

    if batch_sampler is None:
        # No bucketing: the DataLoader itself handles shuffling and batching.
        return DataLoader(
            dataset,
            shuffle=shuffle,
            collate_fn=flux_image_caption_collate_fn,
            batch_size=batch_size,
            num_workers=config.dataloader_num_workers,
        )
    else:
        # Bucketing: the batch sampler yields whole batches, so batch_size/shuffle are not passed here.
        return DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            collate_fn=flux_image_caption_collate_fn,
            num_workers=config.dataloader_num_workers,
        )
================================================
FILE: src/invoke_training/_shared/data/data_loaders/image_caption_sd_dataloader.py
================================================
import typing
import torch
from torch.utils.data import DataLoader
from invoke_training._shared.data.datasets.build_dataset import (
build_hf_hub_image_caption_dataset,
build_image_caption_dir_dataset,
build_image_caption_jsonl_dataset,
)
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler
from invoke_training._shared.data.transforms.caption_prefix_transform import CaptionPrefixTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager
from invoke_training.config.data.data_loader_config import AspectRatioBucketConfig, ImageCaptionSDDataLoaderConfig
from invoke_training.config.data.dataset_config import (
HFHubImageCaptionDatasetConfig,
ImageCaptionDirDatasetConfig,
ImageCaptionJsonlDatasetConfig,
)
def sd_image_caption_collate_fn(examples):
    """Collate a list of image-caption examples into a single batch dict.

    Shared by the Stable Diffusion image-caption and textual inversion data loaders. Optional fields are included in the
    batch only when they are present in the first example.

    Args:
        examples: A list of per-example dicts. Every example must contain an "id" key.

    Returns:
        dict: The batch. "image", "prompt_embeds", "pooled_prompt_embeds", "text_encoder_output", "vae_output" and
        "mask" are combined with `torch.stack`; "loss_weight" is converted with `torch.tensor`; "id",
        "original_size_hw", "crop_top_left_yx" and "caption" are gathered into plain lists.
    """
    first = examples[0]

    def gather(key):
        return [ex[key] for ex in examples]

    batch = {"id": gather("id")}
    if "image" in first:
        batch["image"] = torch.stack(gather("image"))
    for list_key in ("original_size_hw", "crop_top_left_yx", "caption"):
        if list_key in first:
            batch[list_key] = gather(list_key)
    if "loss_weight" in first:
        batch["loss_weight"] = torch.tensor(gather("loss_weight"))
    if "prompt_embeds" in first:
        # "pooled_prompt_embeds" is expected to accompany "prompt_embeds".
        batch["prompt_embeds"] = torch.stack(gather("prompt_embeds"))
        batch["pooled_prompt_embeds"] = torch.stack(gather("pooled_prompt_embeds"))
    for tensor_key in ("text_encoder_output", "vae_output", "mask"):
        if tensor_key in first:
            batch[tensor_key] = torch.stack(gather(tensor_key))
    return batch
def build_aspect_ratio_bucket_manager(config: AspectRatioBucketConfig):
    """Create an AspectRatioBucketManager from the constraints held in an AspectRatioBucketConfig.

    Args:
        config (AspectRatioBucketConfig): Supplies `target_resolution`, `start_dim`, `end_dim` and `divisible_by`.

    Returns:
        AspectRatioBucketManager: The result of `AspectRatioBucketManager.from_constraints(...)`.
    """
    constraints = dict(
        target_resolution=config.target_resolution,
        start_dim=config.start_dim,
        end_dim=config.end_dim,
        divisible_by=config.divisible_by,
    )
    return AspectRatioBucketManager.from_constraints(**constraints)
def build_image_caption_sd_dataloader(  # noqa: C901
    config: ImageCaptionSDDataLoaderConfig,
    batch_size: int,
    use_masks: bool = False,
    text_encoder_output_cache_dir: typing.Optional[str] = None,
    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,
    vae_output_cache_dir: typing.Optional[str] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Construct a DataLoader for an image-caption dataset for Stable Diffusion models.

    Args:
        config (ImageCaptionSDDataLoaderConfig): The dataset config.
        batch_size (int): The DataLoader batch size.
        use_masks (bool, optional): If True, the "mask" field is kept: it is transformed alongside the image, or loaded
            from the VAE cache when `vae_output_cache_dir` is set. If False, any "mask" field is dropped.
        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should be
            loaded from (via a LoadCacheTransform keyed on the example "id").
        text_encoder_cache_field_to_output_field (dict[str, str], optional): Mapping from text encoder cache field names
            to output example field names. Required when `text_encoder_output_cache_dir` is set.
        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If
            set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM.
        shuffle (bool, optional): Whether to shuffle the dataset order.

    Raises:
        ValueError: If the dataset config type is not supported, or if `text_encoder_output_cache_dir` is set without
            `text_encoder_cache_field_to_output_field`.

    Returns:
        DataLoader
    """
    # Validate with an explicit exception rather than `assert` (which is stripped when Python runs with -O), and do it
    # before loading the dataset so that misconfiguration fails fast.
    if text_encoder_output_cache_dir is not None and text_encoder_cache_field_to_output_field is None:
        raise ValueError(
            "text_encoder_cache_field_to_output_field must be provided when text_encoder_output_cache_dir is set."
        )

    if isinstance(config.dataset, HFHubImageCaptionDatasetConfig):
        base_dataset = build_hf_hub_image_caption_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionJsonlDatasetConfig):
        base_dataset = build_image_caption_jsonl_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionDirDatasetConfig):
        base_dataset = build_image_caption_dir_dataset(config.dataset)
    else:
        raise ValueError(f"Unexpected dataset config type: '{type(config.dataset)}'.")

    # Initialize either the fixed target resolution or aspect ratio buckets.
    if config.aspect_ratio_buckets is None:
        target_resolution = config.resolution
        aspect_ratio_bucket_manager = None
        batch_sampler = None
    else:
        # With bucketing, target sizes come from the bucket manager rather than a fixed resolution.
        target_resolution = None
        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)
        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.
        batch_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
            bucket_manager=aspect_ratio_bucket_manager,
            image_sizes=base_dataset.get_image_dimensions(),
            batch_size=batch_size,
            shuffle=shuffle,
            seed=0,
        )

    all_transforms = []
    if config.caption_prefix is not None:
        # The trailing space separates the prefix from the original caption.
        all_transforms.append(CaptionPrefixTransform(caption_field_name="caption", prefix=config.caption_prefix + " "))

    if vae_output_cache_dir is None:
        image_field_names = ["image"]
        if use_masks:
            image_field_names.append("mask")
        else:
            all_transforms.append(DropFieldTransform("mask"))
        all_transforms.append(
            SDImageTransform(
                image_field_names=image_field_names,
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )
    else:
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))
        all_transforms.append(DropFieldTransform("mask"))
        vae_cache = TensorDiskCache(vae_output_cache_dir)
        cache_field_to_output_field = {
            "vae_output": "vae_output",
            "original_size_hw": "original_size_hw",
            "crop_top_left_yx": "crop_top_left_yx",
        }
        if use_masks:
            # The mask is re-loaded from the VAE cache (the raw mask field was dropped above).
            cache_field_to_output_field["mask"] = "mask"
        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache,
                cache_key_field="id",
                cache_field_to_output_field=cache_field_to_output_field,
            )
        )

    if text_encoder_output_cache_dir is not None:
        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=text_encoder_cache,
                cache_key_field="id",
                cache_field_to_output_field=text_encoder_cache_field_to_output_field,
            )
        )

    dataset = TransformDataset(base_dataset, all_transforms)

    if batch_sampler is None:
        return DataLoader(
            dataset,
            shuffle=shuffle,
            collate_fn=sd_image_caption_collate_fn,
            batch_size=batch_size,
            num_workers=config.dataloader_num_workers,
        )
    # DataLoader's `batch_sampler` is mutually exclusive with `shuffle`/`batch_size`; the sampler handles both.
    return DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        collate_fn=sd_image_caption_collate_fn,
        num_workers=config.dataloader_num_workers,
    )
================================================
FILE: src/invoke_training/_shared/data/data_loaders/image_pair_preference_sd_dataloader.py
================================================
import typing
import torch
from torch.utils.data import DataLoader
from invoke_training._shared.data.datasets.build_dataset import build_hf_image_pair_preference_dataset
from invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training.pipelines._experimental.sd_dpo_lora.config import ImagePairPreferenceSDDataLoaderConfig
def sd_image_pair_preference_collate_fn(examples):
    """Collate a list of image-pair preference examples into a batch dict.

    Args:
        examples: A list of per-example dicts. The set of keys to collate is taken from the first example.

    Raises:
        ValueError: If the first example contains a key that this function does not know how to collate.

    Returns:
        dict: Tensor fields are combined with `torch.stack`; all other known fields are gathered into plain lists.
    """
    tensor_keys = {"image_0", "image_1", "prompt_embeds", "pooled_prompt_embeds", "text_encoder_output", "vae_output"}
    plain_keys = {
        "id",
        "original_size_hw_0",
        "original_size_hw_1",
        "crop_top_left_yx_0",
        "crop_top_left_yx_1",
        "prefer_0",
        "prefer_1",
        "caption",
    }

    present = examples[0]
    unknown = set(present.keys()) - (tensor_keys | plain_keys)
    if unknown:
        raise ValueError(f"The following keys are not handled by the collate function: {unknown}.")

    batch = {key: torch.stack([ex[key] for ex in examples]) for key in tensor_keys if key in present}
    batch.update({key: [ex[key] for ex in examples] for key in plain_keys if key in present})
    return batch
def build_image_pair_preference_sd_dataloader(
    config: ImagePairPreferenceSDDataLoaderConfig,
    batch_size: int,
    text_encoder_output_cache_dir: typing.Optional[str] = None,
    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,
    vae_output_cache_dir: typing.Optional[str] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Construct a DataLoader for an image-pair preference dataset for Stable Diffusion.

    Both "image_0" and "image_1" are resized/cropped by their own SDImageTransform, which writes per-image
    "original_size_hw_{0,1}" and "crop_top_left_yx_{0,1}" fields.

    Args:
        config (ImagePairPreferenceSDDataLoaderConfig): The data loader config.
        batch_size (int): The DataLoader batch size.
        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should be
            loaded from (via a LoadCacheTransform keyed on the example "id").
        text_encoder_cache_field_to_output_field (dict[str, str], optional): Mapping from text encoder cache field names
            to output example field names. Required when `text_encoder_output_cache_dir` is set.
        vae_output_cache_dir (str, optional): VAE output caching is not yet supported by this data loader; passing a
            value raises NotImplementedError.
        shuffle (bool, optional): Whether to shuffle the dataset order.

    Raises:
        NotImplementedError: If `vae_output_cache_dir` is set.
        ValueError: If the dataset type is not supported, or if `text_encoder_output_cache_dir` is set without
            `text_encoder_cache_field_to_output_field`.

    Returns:
        DataLoader
    """
    # Validate arguments before building the dataset: the HF hub dataset can be a large download.
    if vae_output_cache_dir is not None:
        # TODO: Support VAE caching by loading "vae_output", "original_size_hw" and "crop_top_left_yx" with a
        # LoadCacheTransform keyed on "id" and dropping the raw image fields, as the image-caption data loaders do.
        raise NotImplementedError("VAE caching is not yet implemented.")
    if text_encoder_output_cache_dir is not None and text_encoder_cache_field_to_output_field is None:
        raise ValueError(
            "text_encoder_cache_field_to_output_field must be provided when text_encoder_output_cache_dir is set."
        )

    dataset_type = config.dataset.type
    if dataset_type == "HF_HUB_IMAGE_PAIR_PREFERENCE_DATASET":
        base_dataset = build_hf_image_pair_preference_dataset(config=config.dataset)
    elif dataset_type == "IMAGE_PAIR_PREFERENCE_DATASET":
        base_dataset = ImagePairPreferenceDataset(dataset_dir=config.dataset.dataset_dir)
    else:
        raise ValueError(f"Unexpected dataset config type: '{dataset_type}' ({type(config.dataset)}).")

    all_transforms = []
    # TODO(ryand): Should I process both images in a single SDImageTransform so that they undergo the same
    # transformations?
    for image_idx in (0, 1):
        all_transforms.append(
            SDImageTransform(
                image_field_names=[f"image_{image_idx}"],
                fields_to_normalize_to_range_minus_one_to_one=[f"image_{image_idx}"],
                resolution=config.resolution,
                aspect_ratio_bucket_manager=None,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
                orig_size_field_name=f"original_size_hw_{image_idx}",
                crop_field_name=f"crop_top_left_yx_{image_idx}",
            )
        )

    if text_encoder_output_cache_dir is not None:
        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=text_encoder_cache,
                cache_key_field="id",
                cache_field_to_output_field=text_encoder_cache_field_to_output_field,
            )
        )

    dataset = TransformDataset(base_dataset, all_transforms)
    return DataLoader(
        dataset,
        shuffle=shuffle,
        collate_fn=sd_image_pair_preference_collate_fn,
        batch_size=batch_size,
        num_workers=config.dataloader_num_workers,
    )
================================================
FILE: src/invoke_training/_shared/data/data_loaders/textual_inversion_sd_dataloader.py
================================================
from typing import Literal, Optional
from torch.utils.data import DataLoader
from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
build_aspect_ratio_bucket_manager,
sd_image_caption_collate_fn,
)
from invoke_training._shared.data.datasets.build_dataset import (
build_hf_hub_image_caption_dataset,
build_image_caption_dir_dataset,
build_image_caption_jsonl_dataset,
)
from invoke_training._shared.data.datasets.image_dir_dataset import ImageDirDataset
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler
from invoke_training._shared.data.transforms.concat_fields_transform import ConcatFieldsTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform
from invoke_training._shared.data.transforms.shuffle_caption_transform import ShuffleCaptionTransform
from invoke_training._shared.data.transforms.template_caption_transform import TemplateCaptionTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig
from invoke_training.config.data.dataset_config import (
HFHubImageCaptionDatasetConfig,
ImageCaptionDirDatasetConfig,
ImageCaptionJsonlDatasetConfig,
ImageDirDatasetConfig,
)
def get_preset_ti_caption_templates(preset: Literal["object", "style"]) -> list[str]:
    """Return the built-in textual inversion caption templates for a named preset.

    Every template contains a single "{}" slot for the placeholder token.

    Args:
        preset: Either "object" or "style".

    Raises:
        ValueError: If `preset` is not one of the recognized preset names.

    Returns:
        list[str]: A freshly created list of templates (callers may mutate it safely).
    """
    object_templates = (
        "a photo of a {}",
        "a rendering of a {}",
        "a cropped photo of the {}",
        "the photo of a {}",
        "a photo of a clean {}",
        "a photo of a dirty {}",
        "a dark photo of the {}",
        "a photo of my {}",
        "a photo of the cool {}",
        "a close-up photo of a {}",
        "a bright photo of the {}",
        "a cropped photo of a {}",
        "a photo of the {}",
        "a good photo of the {}",
        "a photo of one {}",
        "a close-up photo of the {}",
        "a rendition of the {}",
        "a photo of the clean {}",
        "a rendition of a {}",
        "a photo of a nice {}",
        "a good photo of a {}",
        "a photo of the nice {}",
        "a photo of the small {}",
        "a photo of the weird {}",
        "a photo of the large {}",
        "a photo of a cool {}",
        "a photo of a small {}",
    )
    style_templates = (
        "a painting in the style of {}",
        "a rendering in the style of {}",
        "a cropped painting in the style of {}",
        "the painting in the style of {}",
        "a clean painting in the style of {}",
        "a dirty painting in the style of {}",
        "a dark painting in the style of {}",
        "a picture in the style of {}",
        "a cool painting in the style of {}",
        "a close-up painting in the style of {}",
        "a bright painting in the style of {}",
        "a good painting in the style of {}",
        "a close-up painting in the style of {}",
        "a rendition in the style of {}",
        "a nice painting in the style of {}",
        "a small painting in the style of {}",
        "a weird painting in the style of {}",
        "a large painting in the style of {}",
        "a photo in the style of {}",
        "an image in the style of {}",
        "a drawing in the style of {}",
        "a sketch in the style of {}",
        "a digital work in the style of {}",
        "a digital rendering in the style of {}",
        "a photograph in the style of {}",
        "photography in the style of {}",
    )

    for preset_name, templates in (("object", object_templates), ("style", style_templates)):
        if preset == preset_name:
            return list(templates)
    raise ValueError(f"Unrecognized learnable property type: '{preset}'.")
def build_textual_inversion_sd_dataloader(  # noqa: C901
    config: TextualInversionSDDataLoaderConfig,
    placeholder_token: str,
    batch_size: int,
    use_masks: bool = False,
    vae_output_cache_dir: Optional[str] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Construct a DataLoader for a Textual Inversion dataset for Stable Diffusion.

    Captions are generated from templates (either `config.caption_templates` or the preset named by
    `config.caption_preset`) filled in with `placeholder_token`. If `config.keep_original_captions` is set, the templated
    text is written to a "caption_prefix" field and concatenated (space-separated) in front of the dataset's existing
    "caption" field instead of replacing it.

    Args:
        config (TextualInversionSDDataLoaderConfig): The dataset config.
        placeholder_token (str): The placeholder token being trained.
        batch_size (int): The DataLoader batch size.
        use_masks (bool, optional): If True, the "mask" field is kept: it is transformed alongside the image, or loaded
            from the VAE cache when `vae_output_cache_dir` is set. If False, any "mask" field is dropped.
        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If
            set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM.
        shuffle (bool, optional): Whether to shuffle the dataset order.

    Raises:
        ValueError: If not exactly one of `config.caption_templates` / `config.caption_preset` is set, or if the
            dataset config type is not supported.

    Returns:
        DataLoader
    """
    # Validate the caption configuration before doing any (potentially expensive) dataset loading. Exactly one of the
    # two options must be set; the message previously implied that setting both was fine.
    if (config.caption_templates is None) == (config.caption_preset is None):
        raise ValueError("Either caption_templates or caption_preset must be set, but not both.")

    if isinstance(config.dataset, HFHubImageCaptionDatasetConfig):
        base_dataset = build_hf_hub_image_caption_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionJsonlDatasetConfig):
        base_dataset = build_image_caption_jsonl_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionDirDatasetConfig):
        base_dataset = build_image_caption_dir_dataset(config.dataset)
    elif isinstance(config.dataset, ImageDirDatasetConfig):
        base_dataset = ImageDirDataset(
            image_dir=config.dataset.dataset_dir, keep_in_memory=config.dataset.keep_in_memory
        )
    else:
        raise ValueError(f"Unexpected dataset config type: '{type(config.dataset)}'.")

    # Initialize either the fixed target resolution or aspect ratio buckets.
    if config.aspect_ratio_buckets is None:
        target_resolution = config.resolution
        aspect_ratio_bucket_manager = None
        batch_sampler = None
    else:
        target_resolution = None
        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)
        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.
        batch_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
            bucket_manager=aspect_ratio_bucket_manager,
            image_sizes=base_dataset.get_image_dimensions(),
            batch_size=batch_size,
            shuffle=shuffle,
            seed=0,
        )

    if config.caption_templates is not None:
        caption_templates = config.caption_templates
    else:
        caption_templates = get_preset_ti_caption_templates(config.caption_preset)

    # Without keep_original_captions, this overwrites the caption field (typically used with an ImageDirDataset that
    # does not have captions).
    all_transforms = [
        TemplateCaptionTransform(
            field_name="caption_prefix" if config.keep_original_captions else "caption",
            placeholder_str=placeholder_token,
            caption_templates=caption_templates,
        )
    ]
    if config.keep_original_captions:
        # This will only work with a dataset that already has captions (e.g. HFHubImageCaptionDataset or
        # ImageCaptionDirDataset).
        all_transforms.append(
            ConcatFieldsTransform(
                src_field_names=["caption_prefix", "caption"], dst_field_name="caption", separator=" "
            )
        )

    if config.shuffle_caption_delimiter is not None:
        all_transforms.append(ShuffleCaptionTransform(field_name="caption", delimiter=config.shuffle_caption_delimiter))

    if vae_output_cache_dir is None:
        image_field_names = ["image"]
        if use_masks:
            image_field_names.append("mask")
        else:
            all_transforms.append(DropFieldTransform("mask"))
        all_transforms.append(
            SDImageTransform(
                image_field_names=image_field_names,
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )
    else:
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))
        all_transforms.append(DropFieldTransform("mask"))
        vae_cache = TensorDiskCache(vae_output_cache_dir)
        cache_field_to_output_field = {
            "vae_output": "vae_output",
            "original_size_hw": "original_size_hw",
            "crop_top_left_yx": "crop_top_left_yx",
        }
        if use_masks:
            cache_field_to_output_field["mask"] = "mask"
        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache,
                cache_key_field="id",
                cache_field_to_output_field=cache_field_to_output_field,
            )
        )

    dataset = TransformDataset(base_dataset, all_transforms)

    if batch_sampler is None:
        return DataLoader(
            dataset,
            shuffle=shuffle,
            collate_fn=sd_image_caption_collate_fn,
            batch_size=batch_size,
            num_workers=config.dataloader_num_workers,
            persistent_workers=config.dataloader_num_workers > 0,
        )
    return DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        collate_fn=sd_image_caption_collate_fn,
        num_workers=config.dataloader_num_workers,
        persistent_workers=config.dataloader_num_workers > 0,
    )
================================================
FILE: src/invoke_training/_shared/data/datasets/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/data/datasets/build_dataset.py
================================================
from datasets import VerificationMode
from invoke_training._shared.data.datasets.hf_image_caption_dataset import HFImageCaptionDataset
from invoke_training._shared.data.datasets.hf_image_pair_preference_dataset import HFImagePairPreferenceDataset
from invoke_training._shared.data.datasets.image_caption_dir_dataset import ImageCaptionDirDataset
from invoke_training._shared.data.datasets.image_caption_jsonl_dataset import ImageCaptionJsonlDataset
from invoke_training.config.data.dataset_config import (
HFHubImageCaptionDatasetConfig,
ImageCaptionDirDatasetConfig,
ImageCaptionJsonlDatasetConfig,
)
from invoke_training.pipelines._experimental.sd_dpo_lora.config import HFHubImagePairPreferenceDatasetConfig
def build_hf_hub_image_caption_dataset(config: HFHubImageCaptionDatasetConfig) -> HFImageCaptionDataset:
    """Load an HFImageCaptionDataset from the Hugging Face Hub, as described by `config`.

    Args:
        config (HFHubImageCaptionDatasetConfig): Supplies the hub dataset name, the dataset config name, the HF cache
            directory, and the image/caption column names.

    Returns:
        HFImageCaptionDataset
    """
    load_kwargs = {"name": config.dataset_config_name, "cache_dir": config.hf_cache_dir}
    return HFImageCaptionDataset.from_hub(
        dataset_name=config.dataset_name,
        hf_load_dataset_kwargs=load_kwargs,
        image_column=config.image_column,
        caption_column=config.caption_column,
    )
def build_image_caption_jsonl_dataset(config: ImageCaptionJsonlDatasetConfig) -> ImageCaptionJsonlDataset:
    """Create an ImageCaptionJsonlDataset from a .jsonl file described by `config`.

    The return annotation previously claimed `HFImageCaptionDataset`, but this function constructs and returns an
    `ImageCaptionJsonlDataset`.

    Args:
        config (ImageCaptionJsonlDatasetConfig): Supplies the jsonl path, the image/caption column names and the
            keep_in_memory flag.

    Returns:
        ImageCaptionJsonlDataset
    """
    return ImageCaptionJsonlDataset(
        jsonl_path=config.jsonl_path,
        image_column=config.image_column,
        caption_column=config.caption_column,
        keep_in_memory=config.keep_in_memory,
    )
def build_image_caption_dir_dataset(config: ImageCaptionDirDatasetConfig) -> ImageCaptionDirDataset:
    """Create an ImageCaptionDirDataset rooted at `config.dataset_dir`, honoring `config.keep_in_memory`."""
    dataset_dir = config.dataset_dir
    keep_in_memory = config.keep_in_memory
    return ImageCaptionDirDataset(dataset_dir=dataset_dir, keep_in_memory=keep_in_memory)
def build_hf_image_pair_preference_dataset(
    config: HFHubImagePairPreferenceDatasetConfig,
) -> HFImagePairPreferenceDataset:
    """Load a small, fixed slice of the 'yuvalkirstain/pickapic_v2' training split from the Hugging Face Hub.

    HACK(ryand): `config` is currently unused. The dataset name and the list of parquet shards are hard-coded so that
    only a small slice of the very large 'yuvalkirstain/pickapic_v2' dataset is downloaded.

    Args:
        config (HFHubImagePairPreferenceDatasetConfig): Currently ignored.

    Returns:
        HFImagePairPreferenceDataset: The "train" split built from the shards listed below.
    """
    # Other available split, not currently used:
    # "validation_unique": "data/validation_unique-00000-of-00001-33ead111845fc9c4.parquet",
    train_shards = [
        "data/train-00000-of-00645-b66ac786bf6fb553.parquet",
        "data/train-00001-of-00645-c7b349dd222d6515.parquet",
        "data/train-00002-of-00645-e4f54d615a978deb.parquet",
        "data/train-00003-of-00645-2b9d59bac8b433ff.parquet",
        "data/train-00004-of-00645-e4964649dc0ea543.parquet",
        "data/train-00005-of-00645-45e8efc0fe93f6e9.parquet",
    ]
    load_kwargs = {
        "data_files": {"train": train_shards},
        # Disable checks so that `datasets` doesn't complain that the other splits haven't been downloaded.
        "verification_mode": VerificationMode.NO_CHECKS,
    }
    return HFImagePairPreferenceDataset.from_hub(
        "yuvalkirstain/pickapic_v2",
        split="train",
        hf_load_dataset_kwargs=load_kwargs,
    )
================================================
FILE: src/invoke_training/_shared/data/datasets/hf_image_caption_dataset.py
================================================
import os
import typing
import datasets
import torch.utils.data
from PIL.Image import Image
from invoke_training._shared.data.utils.resolution import Resolution
class HFImageCaptionDataset(torch.utils.data.Dataset):
    """An image-caption dataset wrapper for Hugging Face datasets.

    The wrapped HF dataset can be either from the HF hub, or in Imagefolder format
    (https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder). Only the "train" split is used.

    Regardless of the source column names, examples are exposed with the normalized keys "image" and "caption" (plus
    "id", added in `__getitem__`).
    """

    def __init__(self, hf_dataset, image_column: str = "image", caption_column: str = "text"):
        """
        Args:
            hf_dataset: A Hugging Face dataset dict that contains a "train" split.
            image_column (str, optional): The name of the image column in the dataset. Defaults to "image".
            caption_column (str, optional): The name of the caption column in the dataset. Defaults to "text".

        Raises:
            ValueError: If `image_column` or `caption_column` is not a column of the "train" split.
        """
        train_split = hf_dataset["train"]
        column_names = train_split.column_names
        if image_column not in column_names:
            raise ValueError(
                f"The image_column='{image_column}' is not in the set of dataset column names: '{column_names}'."
            )
        if caption_column not in column_names:
            raise ValueError(
                f"The caption_column='{caption_column}' is not in the set of dataset column names: '{column_names}'."
            )

        self._image_column = image_column

        def preprocess(examples):
            # Batched transform: renames the source columns to "image"/"caption" and converts images to RGB.
            images = [image.convert("RGB") for image in examples[image_column]]
            return {
                "image": images,
                "caption": examples[caption_column],
            }

        self._hf_dataset = train_split.with_transform(preprocess)

    @classmethod
    def from_dir(
        cls,
        dataset_dir: str,
        hf_load_dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None,
        image_column: str = "image",
        caption_column: str = "text",
    ):
        """Initialize a HFImageCaptionDataset from a Hugging Face ImageFolder dataset directory
        (https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder).

        Args:
            dataset_dir (str): The path to the dataset directory.
            hf_load_dataset_kwargs (dict[str, typing.Any], optional): kwargs to forward to `datasets.load_dataset(...)`.
            image_column (str, optional): The name of the image column in the dataset. Defaults to "image".
            caption_column (str, optional): The name of the caption column in the dataset. Defaults to "text".
        """
        hf_load_dataset_kwargs = hf_load_dataset_kwargs or {}
        data_files = {"train": os.path.join(dataset_dir, "**")}
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
        hf_dataset = datasets.load_dataset("imagefolder", data_files=data_files, **hf_load_dataset_kwargs)
        return cls(hf_dataset=hf_dataset, image_column=image_column, caption_column=caption_column)

    @classmethod
    def from_hub(
        cls,
        dataset_name: str,
        hf_load_dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None,
        image_column: str = "image",
        caption_column: str = "text",
    ):
        """Initialize a HFImageCaptionDataset from a Hugging Face Hub dataset.

        Args:
            dataset_name (str): The HF Hub dataset name (a.k.a. path).
            hf_load_dataset_kwargs (dict[str, typing.Any], optional): kwargs to forward to `datasets.load_dataset(...)`.
            image_column (str, optional): The name of the image column in the dataset. Defaults to "image".
            caption_column (str, optional): The name of the caption column in the dataset. Defaults to "text".
        """
        hf_load_dataset_kwargs = hf_load_dataset_kwargs or {}
        hf_dataset = datasets.load_dataset(dataset_name, **hf_load_dataset_kwargs)
        return cls(hf_dataset=hf_dataset, image_column=image_column, caption_column=caption_column)

    def get_image_dimensions(self) -> list[Resolution]:
        """Get the dimensions of all images in the dataset.

        TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to
        calculate this dynamically every time.

        Returns:
            list[Resolution]: One (height, width) Resolution per example, in dataset order.
        """
        image_dims: list[Resolution] = []
        for i in range(len(self._hf_dataset)):
            # `self._hf_dataset` applies `preprocess`, whose output uses the normalized key "image". Indexing with the
            # source column name (`self._image_column`) raised a KeyError whenever image_column != "image".
            image = self._hf_dataset[i]["image"]
            image_dims.append(Resolution(image.height, image.width))
        return image_dims

    def __len__(self) -> int:
        """Get the dataset length.

        Returns:
            int: The number of image-caption pairs in the dataset.
        """
        return len(self._hf_dataset)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        """Load the dataset example at index `idx`.

        Raises:
            IndexError: If `idx` is out of range.

        Returns:
            dict: A dataset example with 3 keys: "image", "caption", and "id".
                The "image" key maps to a `PIL` image in RGB format.
                The "caption" key maps to a string.
                The "id" key is the example's index (often used for caching).
        """
        example = self._hf_dataset[idx]
        example["id"] = idx
        return example
================================================
FILE: src/invoke_training/_shared/data/datasets/hf_image_pair_preference_dataset.py
================================================
import io
import typing
import datasets
import torch.utils.data
from PIL import Image
class HFImagePairPreferenceDataset(torch.utils.data.Dataset):
    """A wrapper for the Hugging Face hub "yuvalkirstain/pickapic_v2" dataset
    (https://huggingface.co/datasets/yuvalkirstain/pickapic_v2).

    Designed to be expanded in the future to other HF image pair preference datasets.
    """

    def __init__(
        self,
        hf_dataset,
        skip_no_preference=True,
        split: str = "train",
        image_0_column: str = "jpg_0",
        label_0_column: str = "label_0",
        image_1_column: str = "jpg_1",
        label_1_column: str = "label_1",
        caption_column: str = "caption",
    ):
        """
        Args:
            hf_dataset: A Hugging Face dataset dict that contains `split`.
            skip_no_preference (bool, optional): If True, skip image pairs without a preference (pairs whose labels
                differ by no more than a small epsilon).
            split (str, optional): The split of `hf_dataset` to use. Defaults to "train".
            image_0_column (str, optional): Column holding the encoded bytes of the first image.
            label_0_column (str, optional): Column holding the numeric preference label of the first image.
            image_1_column (str, optional): Column holding the encoded bytes of the second image.
            label_1_column (str, optional): Column holding the numeric preference label of the second image. (The
                default was previously the typo "jpg_1", which validated the image column instead of the label column.)
            caption_column (str, optional): Column holding the caption.

        Raises:
            ValueError: If any of the configured columns is missing from `split`.
        """
        split_dataset = hf_dataset[split]
        column_names = split_dataset.column_names
        for col_name in [image_0_column, label_0_column, image_1_column, label_1_column, caption_column]:
            if col_name not in column_names:
                raise ValueError(f"Column '{col_name}' is not in the set of dataset column names: '{column_names}'.")

        # Label differences no larger than this are treated as a tie.
        eps = 0.0001

        if skip_no_preference:
            # Filter to only include pairs with a clear preference. Use the configured label columns (they were
            # previously hard-coded to "label_0"/"label_1"), and only filter the split that is actually used.
            def has_preference(example: dict[str, typing.Any]) -> bool:
                return abs(example[label_0_column] - example[label_1_column]) > eps

            split_dataset = split_dataset.filter(has_preference)

        def preprocess(examples):
            # Batched transform: decode both images to RGB and derive boolean preference flags from the labels.
            image_0_list = [Image.open(io.BytesIO(image)).convert("RGB") for image in examples[image_0_column]]
            image_1_list = [Image.open(io.BytesIO(image)).convert("RGB") for image in examples[image_1_column]]

            image_0_is_better = []
            image_1_is_better = []
            for label_0, label_1 in zip(examples[label_0_column], examples[label_1_column]):
                if (label_0 - label_1) > eps:
                    # Label 0 is better.
                    image_0_is_better.append(True)
                    image_1_is_better.append(False)
                elif (label_1 - label_0) > eps:
                    # Label 1 is better.
                    image_0_is_better.append(False)
                    image_1_is_better.append(True)
                else:
                    # Tie.
                    image_0_is_better.append(False)
                    image_1_is_better.append(False)

            return {
                "image_0": image_0_list,
                "image_1": image_1_list,
                "prefer_0": image_0_is_better,
                "prefer_1": image_1_is_better,
                "caption": examples[caption_column],
            }

        self._hf_dataset = split_dataset.with_transform(preprocess)

    @classmethod
    def from_hub(
        cls,
        dataset_name: str,
        skip_no_preference: bool = True,
        split: str = "train",
        hf_load_dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None,
    ):
        """Initialize a HFImagePairPreferenceDataset from a Hugging Face Hub dataset.

        Args:
            dataset_name (str): The HF Hub dataset name (a.k.a. path). Only "yuvalkirstain/pickapic_v2" is supported.
            skip_no_preference (bool, optional): If True, skip image pairs without a preference.
            split (str, optional): The split to use. Defaults to "train".
            hf_load_dataset_kwargs (dict[str, typing.Any], optional): kwargs to forward to `datasets.load_dataset(...)`.

        Raises:
            NotImplementedError: If `dataset_name` is not "yuvalkirstain/pickapic_v2".
        """
        if dataset_name != "yuvalkirstain/pickapic_v2":
            raise NotImplementedError(
                "The HFImagePairPreferenceDataset class likely won't work with datasets other than "
                "'yuvalkirstain/pickapic_v2'."
            )
        hf_load_dataset_kwargs = hf_load_dataset_kwargs or {}
        hf_dataset = datasets.load_dataset(dataset_name, **hf_load_dataset_kwargs)
        return cls(hf_dataset=hf_dataset, skip_no_preference=skip_no_preference, split=split)

    def __len__(self) -> int:
        """Get the dataset length.

        Returns:
            int: The number of image pairs in the dataset.
        """
        return len(self._hf_dataset)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        """Load the dataset example at index `idx`.

        Raises:
            IndexError: If `idx` is out of range.

        Returns:
            dict: A dataset example with the keys ["id", "image_0", "image_1", "prefer_0", "prefer_1", "caption"].
                The image keys map to `PIL` images in RGB format.
                The "prefer_*" keys map to bools (both False for a tie).
                The "caption" key maps to a string.
                The "id" key is the example's index (often used for caching).
        """
        example = self._hf_dataset[idx]
        example["id"] = idx
        return example
================================================
FILE: src/invoke_training/_shared/data/datasets/image_caption_dir_dataset.py
================================================
import os
import typing
import torch.utils.data
from PIL import Image
from invoke_training._shared.data.utils.resolution import Resolution
class ImageCaptionDirDataset(torch.utils.data.Dataset):
    """A dataset that loads images and captions from a directory of image files and .txt files.

    Every image file must have a caption file next to it with the same base name and the extension
    `caption_extension` (e.g. `cat.jpg` -> `cat.txt`). Examples are ordered by sorted image path.
    """

    def __init__(
        self,
        dataset_dir: str,
        id_prefix: str = "",
        image_extensions: typing.Optional[list[str]] = None,
        caption_extension: str = ".txt",
        keep_in_memory: bool = False,
    ):
        """Initialize an ImageCaptionDirDataset.

        Args:
            dataset_dir (str): The directory to load images and captions from. Sub-directories are not searched.
            id_prefix (str): A prefix added to the 'id' field in every example.
            image_extensions (list[str], optional): The list of image file extensions to include in the dataset (not
                case-sensitive). Defaults to [".jpg", ".jpeg", ".png"].
            caption_extension (str, optional): The extension of the caption file that accompanies each image.
                Captions are read as UTF-8 and stripped of leading/trailing whitespace. Defaults to ".txt".
            keep_in_memory (bool, optional): If True, keep all images loaded in memory. This improves performance for
                datasets that are small enough to be kept in memory.

        Raises:
            Exception: If the caption file is missing for one or more images.
        """
        super().__init__()
        self._id_prefix = id_prefix

        if image_extensions is None:
            image_extensions = [".jpg", ".jpeg", ".png"]
        image_extensions = [ext.lower() for ext in image_extensions]

        # Determine the list of image paths to include in the dataset. Sorting makes the example order (and therefore
        # the index-based example ids) independent of the filesystem's listing order.
        self._image_paths: list[str] = []
        for image_file in os.listdir(dataset_dir):
            image_path = os.path.join(dataset_dir, image_file)
            if os.path.isfile(image_path) and os.path.splitext(image_path)[1].lower() in image_extensions:
                self._image_paths.append(image_path)
        self._image_paths.sort()

        # Load captions from the caption file of each image, collecting all missing files before failing.
        self._captions: list[str] = []
        missing_captions: list[str] = []
        for image_path in self._image_paths:
            caption_path = os.path.splitext(image_path)[0] + caption_extension
            if os.path.isfile(caption_path):
                # Read explicitly as UTF-8 rather than with the platform default encoding (e.g. cp1252 on Windows).
                with open(caption_path, "r", encoding="utf-8") as f:
                    self._captions.append(f.read().strip())
            else:
                missing_captions.append(caption_path)
        if len(missing_captions) > 0:
            raise Exception(f"The following expected caption files are missing: {missing_captions}")

        self._images = None
        if keep_in_memory:
            self._images = [self._load_image(image_path) for image_path in self._image_paths]

    def _load_image(self, image_path: str) -> Image.Image:
        # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale
        # images. The context manager guarantees the underlying file handle is closed.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def get_image_dimensions(self) -> list[Resolution]:
        """Get the dimensions of all images in the dataset.

        TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to
        calculate this dynamically every time.
        """
        image_dims: list[Resolution] = []
        for image_path in self._image_paths:
            # Image.open() is lazy, so only the header is needed here; close the file handle promptly.
            with Image.open(image_path) as image:
                image_dims.append(Resolution(image.height, image.width))
        return image_dims

    def __len__(self) -> int:
        return len(self._image_paths)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        """Return the example at `idx` as a dict with keys "id", "image" (RGB PIL image) and "caption"."""
        image = self._images[idx] if self._images is not None else self._load_image(self._image_paths[idx])
        return {"id": f"{self._id_prefix}{idx}", "image": image, "caption": self._captions[idx]}
================================================
FILE: src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py
================================================
import typing
from pathlib import Path
import torch.utils.data
from PIL import Image
from pydantic import BaseModel
from invoke_training._shared.data.utils.resolution import Resolution
from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl
# Default jsonl column holding the image path (default for `ImageCaptionJsonlDataset(image_column=...)`).
IMAGE_COLUMN_DEFAULT = "image"
# Default jsonl column holding the caption (default for `ImageCaptionJsonlDataset(caption_column=...)`).
CAPTION_COLUMN_DEFAULT = "text"
# Jsonl column holding the optional mask path. Unlike the two above, this column name is not configurable.
MASK_COLUMN_DEFAULT = "mask"
class ImageCaptionExample(BaseModel):
    """A single record of an image-caption jsonl dataset.

    Paths are stored exactly as they appear in the jsonl file (absolute, or relative to the jsonl file).
    """

    # Path of the image file.
    image_path: str
    # Optional path of a mask image for the example; None when the record has no mask.
    mask_path: typing.Optional[str] = None
    # The caption text.
    caption: str
class ImageCaptionJsonlDataset(torch.utils.data.Dataset):
    """A dataset that loads image-caption pairs (with optional masks) listed in a .jsonl file.

    Each record must contain an image path column and a caption column, and may contain a `MASK_COLUMN_DEFAULT`
    column with a mask path. Relative paths are resolved against the directory containing the .jsonl file.
    """

    def __init__(
        self,
        jsonl_path: Path | str,
        image_column: str = IMAGE_COLUMN_DEFAULT,
        caption_column: str = CAPTION_COLUMN_DEFAULT,
        keep_in_memory: bool = False,
    ):
        """Initialize an ImageCaptionJsonlDataset.

        Args:
            jsonl_path (Path | str): Path of the .jsonl file listing the examples.
            image_column (str, optional): Name of the record field holding the image path.
            caption_column (str, optional): Name of the record field holding the caption.
            keep_in_memory (bool, optional): If True, cache each example after it is first loaded.

        Raises:
            ValueError: If a record is missing `image_column` or `caption_column`.
        """
        super().__init__()
        self._jsonl_path = Path(jsonl_path)
        self._image_column = image_column
        self._caption_column = caption_column
        data = load_jsonl(jsonl_path)

        examples: list[ImageCaptionExample] = []
        for d in data:
            # Clear error messages here are helpful in the Gradio UI.
            if image_column not in d:
                raise ValueError(f"Column '{image_column}' not found in jsonl file '{jsonl_path}'.")
            if caption_column not in d:
                raise ValueError(f"Column '{caption_column}' not found in jsonl file '{jsonl_path}'.")
            examples.append(
                ImageCaptionExample(
                    image_path=d[image_column], mask_path=d.get(MASK_COLUMN_DEFAULT, None), caption=d[caption_column]
                )
            )
        self.examples = examples

        self._keep_in_memory = keep_in_memory
        self._example_cache: dict[int, dict[str, typing.Any]] = {}

    def save_jsonl(self):
        """Write `self.examples` back to the .jsonl file this dataset was loaded from."""
        data = []
        for example in self.examples:
            data.append(
                {
                    self._image_column: example.image_path,
                    self._caption_column: example.caption,
                    MASK_COLUMN_DEFAULT: example.mask_path,
                }
            )
        save_jsonl(data, self._jsonl_path)

    def _resolve_path(self, path: str) -> Path:
        # Paths may be absolute, or relative to the jsonl file.
        resolved = Path(path)
        if not resolved.is_absolute():
            resolved = self._jsonl_path.parent / resolved
        return resolved

    def _get_image_path(self, idx: int) -> Path:
        return self._resolve_path(self.examples[idx].image_path)

    def _get_mask_path(self, idx: int) -> Path:
        # Only valid for examples that have a mask_path.
        return self._resolve_path(self.examples[idx].mask_path)

    def _load_image(self, image_path: str | Path) -> Image.Image:
        # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale
        # images. The context manager guarantees the underlying file handle is closed.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def _load_mask(self, mask_path: str | Path) -> Image.Image:
        # Masks are converted to single-channel ("L") images.
        with Image.open(mask_path) as mask:
            return mask.convert("L")

    def _load_example(self, idx: int) -> dict[str, typing.Any]:
        example = {
            "id": str(idx),
            "image": self._load_image(self._get_image_path(idx)),
            "caption": self.examples[idx].caption,
        }
        if self.examples[idx].mask_path:
            example["mask"] = self._load_mask(self._get_mask_path(idx))
        return example

    def get_image_dimensions(self) -> list[Resolution]:
        """Get the dimensions of all images in the dataset.

        TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to
        calculate this dynamically every time.
        """
        image_dims: list[Resolution] = []
        for i in range(len(self.examples)):
            # Image.open() is lazy, so only the header is needed here; close the file handle promptly.
            with Image.open(self._get_image_path(i)) as image:
                image_dims.append(Resolution(image.height, image.width))
        return image_dims

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        if self._keep_in_memory:
            if idx not in self._example_cache:
                self._example_cache[idx] = self._load_example(idx)
            # Return a shallow copy of the example to prevent the caller from modifying the cached example.
            # Shallow rather than deep, because we don't want to copy the image data.
            return self._example_cache[idx].copy()

        return self._load_example(idx)
================================================
FILE: src/invoke_training/_shared/data/datasets/image_dir_dataset.py
================================================
import os
import typing
import torch.utils.data
from PIL import Image
from invoke_training._shared.data.utils.resolution import Resolution
class ImageDirDataset(torch.utils.data.Dataset):
    """A dataset that loads image files from a directory.

    Examples are ordered by sorted image path so that the index-based ids are stable across runs and filesystems.
    """

    def __init__(
        self,
        image_dir: str,
        id_prefix: str = "",
        image_extensions: typing.Optional[list[str]] = None,
        keep_in_memory: bool = False,
    ):
        """Initialize an ImageDirDataset

        Args:
            image_dir (str): The directory to load images from. Sub-directories are not searched.
            id_prefix (str): A prefix added to the 'id' field in every example.
            image_extensions (list[str], optional): The list of image file extensions to include in the dataset (not
                case-sensitive). Defaults to [".jpg", ".jpeg", ".png"].
            keep_in_memory (bool, optional): If True, keep all images loaded in memory. This improves performance for
                datasets that are small enough to be kept in memory.
        """
        super().__init__()
        self._id_prefix = id_prefix

        if image_extensions is None:
            image_extensions = [".jpg", ".jpeg", ".png"]
        image_extensions = [ext.lower() for ext in image_extensions]

        # os.listdir() returns entries in arbitrary order; sort so the example order (and the index-based ids, which
        # are used e.g. as cache keys) is deterministic.
        self._image_paths: list[str] = []
        for image_file in os.listdir(image_dir):
            image_path = os.path.join(image_dir, image_file)
            if os.path.isfile(image_path) and os.path.splitext(image_path)[1].lower() in image_extensions:
                self._image_paths.append(image_path)
        self._image_paths.sort()

        self._images = None
        if keep_in_memory:
            self._images = [self._load_image(image_path) for image_path in self._image_paths]

    def _load_image(self, image_path: str) -> Image.Image:
        # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale
        # images. The context manager guarantees the underlying file handle is closed.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def get_image_dimensions(self) -> list[Resolution]:
        """Get the dimensions of all images in the dataset.

        TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to
        calculate this dynamically every time.
        """
        image_dims: list[Resolution] = []
        for image_path in self._image_paths:
            # Image.open() is lazy, so only the header is needed here; close the file handle promptly.
            with Image.open(image_path) as image:
                image_dims.append(Resolution(image.height, image.width))
        return image_dims

    def __len__(self) -> int:
        return len(self._image_paths)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        """Return the example at `idx` as a dict with keys "id" and "image" (RGB PIL image)."""
        image = self._images[idx] if self._images is not None else self._load_image(self._image_paths[idx])
        return {"id": f"{self._id_prefix}{idx}", "image": image}
================================================
FILE: src/invoke_training/_shared/data/datasets/image_pair_preference_dataset.py
================================================
import os
import typing
from pathlib import Path
import torch.utils.data
from PIL import Image
from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl
class ImagePairPreferenceDataset(torch.utils.data.Dataset):
    """A dataset of image pairs with preference labels, stored in a directory with a jsonl metadata file.

    Each metadata record must have the keys "image_0" and "image_1" (paths relative to the dataset directory),
    "prompt", "prefer_0" and "prefer_1".
    """

    def __init__(self, dataset_dir: str, metadata_file: str = "metadata.jsonl"):
        """Initialize an ImagePairPreferenceDataset.

        Args:
            dataset_dir (str): The dataset directory. Image paths in the metadata are resolved relative to it.
            metadata_file (str, optional): The name of the metadata file inside `dataset_dir`. Defaults to
                "metadata.jsonl", the same default used by `save_metadata(...)`.
        """
        super().__init__()
        self._dataset_dir = dataset_dir
        self._metadata = load_jsonl(Path(dataset_dir) / metadata_file)

    @classmethod
    def save_metadata(
        cls, metadata: list[dict[str, typing.Any]], dataset_dir: str | Path, metadata_file: str = "metadata.jsonl"
    ) -> Path:
        """Save the dataset metadata records to `dataset_dir / metadata_file`.

        Returns:
            Path: The path of the written metadata file.
        """
        metadata_path = Path(dataset_dir) / metadata_file
        save_jsonl(metadata, metadata_path)
        return metadata_path

    @staticmethod
    def _load_rgb_image(image_path: str) -> Image.Image:
        # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale
        # images. The context manager guarantees the underlying file handle is closed.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def __len__(self) -> int:
        return len(self._metadata)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        """Return the example at `idx` with keys "id", "image_0", "image_1", "caption", "prefer_0", "prefer_1"."""
        example = self._metadata[idx]
        image_0_path = os.path.join(self._dataset_dir, example["image_0"])
        image_1_path = os.path.join(self._dataset_dir, example["image_1"])
        return {
            "id": str(idx),
            "image_0": self._load_rgb_image(image_0_path),
            "image_1": self._load_rgb_image(image_1_path),
            "caption": example["prompt"],
            "prefer_0": example["prefer_0"],
            "prefer_1": example["prefer_1"],
        }
================================================
FILE: src/invoke_training/_shared/data/datasets/transform_dataset.py
================================================
import typing
import torch.utils.data
# The data type expected to be produced by the base dataset and handled by transforms.
DataType = typing.Dict[str, typing.Any]
# A transform maps one example dict to another.
TransformType = typing.Callable[[DataType], DataType]


class TransformDataset(torch.utils.data.Dataset):
    """Wrap a base dataset and pass each of its examples through a pipeline of transforms.

    Transforms are applied in list order, each receiving the output of the previous one.
    """

    def __init__(self, base_dataset: torch.utils.data.Dataset, transforms: list[TransformType]) -> None:
        super().__init__()
        self._base_dataset = base_dataset
        self._transforms = transforms

    def __len__(self) -> int:
        # Transforms never add or remove examples, so the length is that of the base dataset.
        return len(self._base_dataset)

    def __getitem__(self, idx: int) -> DataType:
        result = self._base_dataset[idx]
        for transform_fn in self._transforms:
            result = transform_fn(result)
        return result
================================================
FILE: src/invoke_training/_shared/data/samplers/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/data/samplers/aspect_ratio_bucket_batch_sampler.py
================================================
import copy
import logging
import math
import random
from typing import Iterator
from torch.utils.data import Sampler
from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager
from invoke_training._shared.data.utils.resolution import Resolution
# Map from bucket resolution to the dataset indices assigned to that bucket.
AspectRatioBuckets = dict[Resolution, list[int]]


class AspectRatioBucketBatchSampler(Sampler[list[int]]):
    """A batch sampler that adheres to aspect ratio buckets.

    Every batch contains only dataset indices from a single aspect ratio bucket.
    """

    def __init__(
        self,
        buckets: AspectRatioBuckets,
        batch_size: int,
        shuffle: bool = False,
        seed: int | None = None,
    ) -> None:
        """Initialize AspectRatioBucketBatchSampler.

        For most use cases, initialize via AspectRatioBucketBatchSampler.from_image_sizes(...).

        Args:
            buckets (AspectRatioBuckets): Map from bucket resolution to the dataset indices in that bucket.
            batch_size (int): The max number of indices per batch. The last batch of each bucket may be smaller.
            shuffle (bool, optional): If True, shuffle indices within each bucket and then shuffle the batches.
            seed (int | None, optional): Seed for this sampler's private `random.Random` instance.

        Raises:
            ValueError: If `batch_size` is less than 1 (which would otherwise make iteration loop forever).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, but got {batch_size}.")
        self._buckets = buckets
        self._batch_size = batch_size
        self._shuffle = shuffle
        self._random = random.Random(seed)

    def __str__(self) -> str:
        # One line per bucket: "<resolution tuple>: <number of images>". Read-only, so no need to deep-copy buckets.
        return "".join(
            f" {bucket_resolution.to_tuple()}: {len(self._buckets[bucket_resolution])}\n"
            for bucket_resolution in sorted(self._buckets)
        )

    @classmethod
    def from_image_sizes(
        cls,
        bucket_manager: AspectRatioBucketManager,
        image_sizes: list[Resolution],
        batch_size: int,
        shuffle: bool = False,
        seed: int | None = None,
    ):
        """Initialize from an AspectRatioBucketManager and the list of dataset image resolutions."""
        buckets = cls._build_bucket_to_index_map(bucket_manager, image_sizes)
        return cls(buckets=buckets, batch_size=batch_size, shuffle=shuffle, seed=seed)

    @classmethod
    def _build_bucket_to_index_map(
        cls,
        bucket_manager: AspectRatioBucketManager,
        image_sizes: list[Resolution],
    ) -> AspectRatioBuckets:
        # Start with every known bucket (possibly empty), then assign each image index to its bucket.
        bucket_to_indexes: AspectRatioBuckets = {bucket_resolution: [] for bucket_resolution in bucket_manager.buckets}
        for index, image_size in enumerate(image_sizes):
            aspect_ratio_bucket = bucket_manager.get_aspect_ratio_bucket(image_size)
            bucket_to_indexes[aspect_ratio_bucket].append(index)
        return bucket_to_indexes

    def get_buckets(self) -> AspectRatioBuckets:
        """Return a deep copy of the bucket map, so callers cannot mutate the sampler's state."""
        return copy.deepcopy(self._buckets)

    def __iter__(self) -> Iterator[list[int]]:
        batches: list[list[int]] = []

        # TODO(ryand): If self._shuffle == False, should we still shuffle just with a fixed seed every time? If we
        # don't shuffle at all then all of the batches from a bucket will be grouped together. If there's a correlation
        # between aspect ratio and image content in a dataset, this could result in unevenly distributed image content
        # over the dataset.

        # Buckets are visited in sorted order so that the sequence of RNG draws is deterministic for a given seed.
        for bucket_resolution in sorted(self._buckets):
            ordered_bucket_images = self._buckets[bucket_resolution].copy()
            if self._shuffle:
                # Shuffle the images within a bucket.
                self._random.shuffle(ordered_bucket_images)

            # Prepare batches for a single bucket. Slicing clamps the final (possibly partial) batch.
            for batch_start in range(0, len(ordered_bucket_images), self._batch_size):
                batches.append(ordered_bucket_images[batch_start : batch_start + self._batch_size])

        if self._shuffle:
            # We've already shuffled the images within each bucket, now we shuffle the batches.
            self._random.shuffle(batches)

        yield from batches

    def __len__(self) -> int:
        # Each bucket contributes ceil(len(bucket) / batch_size) batches.
        return sum(math.ceil(len(bucket_images) / self._batch_size) for bucket_images in self._buckets.values())
def log_aspect_ratio_buckets(logger: logging.Logger, batch_sampler: AspectRatioBucketBatchSampler):
    """Log the per-bucket image counts of `batch_sampler` at INFO level.

    Silently does nothing when `batch_sampler` is not an AspectRatioBucketBatchSampler.
    """
    if isinstance(batch_sampler, AspectRatioBucketBatchSampler):
        logger.info("Aspect Ratio Buckets:\n" + str(batch_sampler))
================================================
FILE: src/invoke_training/_shared/data/samplers/batch_offset_sampler.py
================================================
import typing
from torch.utils.data import Sampler
class BatchOffsetSampler(Sampler[list[int]]):
    """A sampler that wraps a batch sampler and applies an offset to all returned batch elements.

    Each batch (a list of indices) from the wrapped sampler is yielded as a new list with `offset` added to every
    index; the wrapped batches are not mutated.
    """

    def __init__(self, sampler: Sampler[list[int]], offset: int):
        """
        Args:
            sampler (Sampler[list[int]]): The batch sampler to wrap. Any iterable of index batches with a length works.
            offset (int): The value added to every index in every batch.
        """
        self._sampler = sampler
        self._offset = offset

    def __iter__(self) -> typing.Iterator[list[int]]:
        for batch in self._sampler:
            yield [x + self._offset for x in batch]

    def __len__(self) -> int:
        # The number of batches is unchanged by the offset.
        return len(self._sampler)
================================================
FILE: src/invoke_training/_shared/data/samplers/concat_sampler.py
================================================
import itertools
import typing
from torch.utils.data import Sampler
T_co = typing.TypeVar("T_co", covariant=True)
class ConcatSampler(Sampler[T_co]):
"""A meta-Sampler that concatenates multiple samplers.
Example:
sampler 1: ABCD
sampler 2: EFG
sampler 3: HIJKLM
ConcatSampler: ABCDEFGHIJKLM
"""
def __init__(self, samplers: list[Sampler[T_co] | typing.Iterable[T_co]]) -> None:
self._samplers = samplers
def __iter__(self) -> typing.Iterator[T_co]:
return itertools.chain(*self._samplers)
def __len__(self) -> int:
return sum([len(s) for s in self._samplers])
================================================
FILE: src/invoke_training/_shared/data/samplers/interleaved_sampler.py
================================================
import typing
from torch.utils.data import Sampler
T_co = typing.TypeVar("T_co", covariant=True)
class InterleavedSampler(Sampler[T_co]):
"""A meta-Sampler that interleaves multiple samplers.
The length of this sampler is based on the length of the shortest input sampler. All samplers will contribute the
same number of samples to the interleaved output.
Example:
sampler 1: ABCD
sampler 2: EFG
sampler 3: HIJKLM
interleaved sampler: AEHBFICGJ
"""
def __init__(self, samplers: list[Sampler[T_co] | typing.Iterable[T_co]]) -> None:
self._samplers = samplers
self._min_sampler_len = min([len(s) for s in self._samplers])
def __iter__(self) -> typing.Iterator[T_co]:
sampler_iters = [iter(s) for s in self._samplers]
while True:
samples = []
for sampler_iter in sampler_iters:
try:
samples.append(next(sampler_iter))
except StopIteration:
# The end of the shortest sampler has been reached.
return
yield from samples
def __len__(self) -> int:
return self._min_sampler_len * len(self._samplers)
================================================
FILE: src/invoke_training/_shared/data/samplers/offset_sampler.py
================================================
import typing
from torch.utils.data import Sampler
class OffsetSampler(Sampler[int]):
    """Wrap a sampler and shift every index it yields by a constant offset."""

    def __init__(self, sampler: Sampler[int], offset: int):
        self._sampler = sampler
        self._offset = offset

    def __iter__(self) -> typing.Iterator[int]:
        yield from (index + self._offset for index in self._sampler)

    def __len__(self) -> int:
        # Shifting indices does not change how many there are.
        return len(self._sampler)
================================================
FILE: src/invoke_training/_shared/data/transforms/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/data/transforms/caption_prefix_transform.py
================================================
import typing
class CaptionPrefixTransform:
    """Prepend a fixed string to the caption field of every example (in place)."""

    def __init__(self, caption_field_name: str, prefix: str):
        """
        Args:
            caption_field_name (str): Key of the caption field to modify.
            prefix (str): String placed before the caption. No separator is inserted.
        """
        self._caption_field_name = caption_field_name
        self._prefix = prefix

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        field = self._caption_field_name
        original_caption = data[field]
        data[field] = self._prefix + original_caption
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/concat_fields_transform.py
================================================
import typing
class ConcatFieldsTransform:
    """Join several string fields of an example into one destination field."""

    def __init__(self, src_field_names: list[str], dst_field_name: str, separator: str = " "):
        """
        Args:
            src_field_names (list[str]): Fields to join, in order.
            dst_field_name (str): Field that receives the joined string (may be one of the sources).
            separator (str, optional): String placed between consecutive values. Defaults to a single space.
        """
        self._src_field_names = src_field_names
        self._dst_field_name = dst_field_name
        self._separator = separator

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        parts = (data[name] for name in self._src_field_names)
        data[self._dst_field_name] = self._separator.join(parts)
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/constant_field_transform.py
================================================
import typing
class ConstantFieldTransform:
    """Set a field to the same fixed value on every example (overwriting any existing value)."""

    def __init__(self, field_name: str, field_value: typing.Any):
        self._field_name = field_name
        self._field_value = field_value

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        name, value = self._field_name, self._field_value
        data[name] = value
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/drop_field_transform.py
================================================
import typing
class DropFieldTransform:
    """Remove a field from an example, if present (a missing field is not an error)."""

    def __init__(self, field_to_drop: str):
        self._field_to_drop = field_to_drop

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        data.pop(self._field_to_drop, None)
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/flux_image_transform.py
================================================
import typing
from torchvision import transforms
from torchvision.transforms.functional import crop
from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager, Resolution
from invoke_training._shared.data.utils.resize import resize_to_cover
class FluxImageTransform:
    """A transform that prepares and augments images for Flux.1-dev training.

    All image fields of an example are resized, cropped and flipped with the same parameters, so paired fields (e.g.
    an image and its mask) stay pixel-aligned. The original size and crop position are recorded in the
    "original_size_hw" and "crop_top_left_yx" fields.
    """

    def __init__(
        self,
        image_field_names: list[str],
        fields_to_normalize_to_range_minus_one_to_one: list[str],
        resolution: int | None = 512,
        aspect_ratio_bucket_manager: AspectRatioBucketManager | None = None,
        random_flip: bool = True,
        center_crop: bool = True,
    ):
        """Initialize FluxImageTransform.

        Args:
            image_field_names (list[str]): The field names of the images to be transformed. The first image determines
                the original size, target resolution and crop position for all fields, so all images are expected to
                have the same size.
            fields_to_normalize_to_range_minus_one_to_one (list[str]): The subset of image fields whose tensors are
                normalized with `(x - 0.5) / 0.5`. Other fields are left as produced by `transforms.ToTensor()`.
            resolution (int | None): The square image resolution that will be produced. If not None it takes
                precedence over `aspect_ratio_bucket_manager`. NOTE: it defaults to 512, so pass `resolution=None`
                explicitly to use `aspect_ratio_bucket_manager`.
            aspect_ratio_bucket_manager (AspectRatioBucketManager): Used to determine the target resolution for each
                image when `resolution` is None.
            random_flip (bool, optional): Whether to apply a random horizontal flip to the images.
            center_crop (bool, optional): If True, crop to the center of the image to achieve the target resolution. If
                False, crop at a random location.
        """
        self.image_field_names = image_field_names
        self.fields_to_normalize_to_range_minus_one_to_one = fields_to_normalize_to_range_minus_one_to_one
        self.resolution = resolution
        self.aspect_ratio_bucket_manager = aspect_ratio_bucket_manager
        self.random_flip = random_flip
        self.center_crop = center_crop

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:  # noqa: C901
        """Transform the image fields of `data` in place and add "original_size_hw" / "crop_top_left_yx".

        Returns:
            The same `data` dict, with every image field replaced by a tensor.
        """
        # Local import: `random` is only needed for the flip augmentation.
        import random

        image_fields: dict = {field_name: data[field_name] for field_name in self.image_field_names}

        # The first image determines the original size and resolution for all fields.
        first_image = next(iter(image_fields.values()))
        original_size_hw = (first_image.height, first_image.width)

        # Determine the target image resolution (once, shared by all fields).
        if self.resolution is not None:
            resolution_obj = Resolution(self.resolution, self.resolution)
        else:
            resolution_obj = self.aspect_ratio_bucket_manager.get_aspect_ratio_bucket(
                Resolution.parse(original_size_hw)
            )

        # Resize every field to cover the target resolution while preserving aspect ratio.
        for field_name, image in image_fields.items():
            image_fields[field_name] = resize_to_cover(image, resolution_obj)
        resized_first_image = next(iter(image_fields.values()))

        # Choose ONE crop position and apply it to every field. Cropping with the recorded offsets (rather than with
        # `transforms.CenterCrop`, which rounds instead of flooring) guarantees that "crop_top_left_yx" matches the
        # crop that was actually applied.
        if self.center_crop:
            top_left_y = max(0, (resized_first_image.height - resolution_obj.height) // 2)
            top_left_x = max(0, (resized_first_image.width - resolution_obj.width) // 2)
        else:
            crop_transform = transforms.RandomCrop(resolution_obj.to_tuple())
            top_left_y, top_left_x, _, _ = crop_transform.get_params(resized_first_image, resolution_obj.to_tuple())
        for field_name, image in image_fields.items():
            image_fields[field_name] = crop(image, top_left_y, top_left_x, resolution_obj.height, resolution_obj.width)

        # Make ONE flip decision for all fields and update the recorded crop x-position accordingly.
        # TODO: Use a seed for repeatable results
        if self.random_flip and random.random() < 0.5:
            cropped_width = next(iter(image_fields.values())).width
            top_left_x = original_size_hw[1] - cropped_width - top_left_x
            flip_transform = transforms.RandomHorizontalFlip(p=1.0)
            for field_name, image in image_fields.items():
                image_fields[field_name] = flip_transform(image)

        # Convert to tensors, normalizing the requested fields, and store the processed images.
        to_tensor_transform = transforms.ToTensor()
        normalize_transform = transforms.Normalize([0.5], [0.5])
        for field_name, image in image_fields.items():
            tensor = to_tensor_transform(image)
            if field_name in self.fields_to_normalize_to_range_minus_one_to_one:
                tensor = normalize_transform(tensor)
            data[field_name] = tensor

        # Add metadata fields expected by VAE caching
        data["original_size_hw"] = original_size_hw
        data["crop_top_left_yx"] = (top_left_y, top_left_x)

        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/load_cache_transform.py
================================================
import typing
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
class LoadCacheTransform:
    """Copy selected fields from a TensorDiskCache entry into each example.

    The cache entry is looked up with the value of the example's `cache_key_field`.
    """

    def __init__(
        self, cache: TensorDiskCache, cache_key_field: str, cache_field_to_output_field: typing.Dict[str, str]
    ):
        """Initialize LoadCacheTransform.

        Args:
            cache (TensorDiskCache): The cache to read entries from.
            cache_key_field (str): The example field whose value is used as the cache key.
            cache_field_to_output_field (typing.Dict[str, str]): Maps each field name in the cached entry to the
                example field it is written to.
        """
        self._cache = cache
        self._cache_key_field = cache_key_field
        self._cache_field_to_output_field = cache_field_to_output_field

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        cached_entry = self._cache.load(data[self._cache_key_field])
        for cache_field, output_field in self._cache_field_to_output_field.items():
            data[output_field] = cached_entry[cache_field]
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/sd_image_transform.py
================================================
import random
import typing
from torchvision import transforms
from torchvision.transforms.functional import crop
from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager, Resolution
from invoke_training._shared.data.utils.resize import resize_to_cover
class SDImageTransform:
    """A transform that prepares and augments images for Stable Diffusion training.

    All fields listed in `image_field_names` are resized, cropped and (optionally) flipped together with the same
    parameters. The original (height, width) and the top-left crop position are written to the example under
    `orig_size_field_name` and `crop_field_name`.
    """

    def __init__(
        self,
        image_field_names: list[str],
        fields_to_normalize_to_range_minus_one_to_one: list[str],
        resolution: int | tuple[int, int] | Resolution | None,
        aspect_ratio_bucket_manager: AspectRatioBucketManager | None = None,
        center_crop: bool = True,
        random_flip: bool = False,
        orig_size_field_name: str = "original_size_hw",
        crop_field_name: str = "crop_top_left_yx",
    ):
        """Initialize SDImageTransform.

        Args:
            image_field_names (list[str]): The field names of the images to be transformed. All of these images must
                have the same size (checked with an `assert` in `__call__`).
            fields_to_normalize_to_range_minus_one_to_one (list[str]): The subset of image fields whose tensors are
                converted from range [0, 1.0] to [-1.0, 1.0]. Other image fields are only converted to tensors.
            resolution (int | tuple[int, int] | Resolution | None): The image resolution that will be produced
                (parsed with `Resolution.parse`). Exactly one of `resolution` and `aspect_ratio_bucket_manager` must
                be set.
            aspect_ratio_bucket_manager (AspectRatioBucketManager): The AspectRatioBucketManager used to determine the
                target resolution for each image. Exactly one of `resolution` and `aspect_ratio_bucket_manager` must
                be set.
            center_crop (bool, optional): If True, crop to the center of the image to achieve the target resolution. If
                False, crop at a random location.
            random_flip (bool, optional): Whether to apply a random horizontal flip to the images.
            orig_size_field_name (str, optional): Output field for the original (height, width) of the images.
            crop_field_name (str, optional): Output field for the (top, left) crop position.

        Raises:
            ValueError: If both or neither of `resolution` and `aspect_ratio_bucket_manager` are set.
        """
        self._image_field_names = image_field_names
        self._fields_to_normalize_to_range_minus_one_to_one = fields_to_normalize_to_range_minus_one_to_one

        if resolution is not None and aspect_ratio_bucket_manager is not None:
            raise ValueError("Only one of `resolution` or `aspect_ratio_bucket_manager` should be set.")
        if resolution is None and aspect_ratio_bucket_manager is None:
            raise ValueError("One of `resolution` or `aspect_ratio_bucket_manager` must be set.")
        self._resolution = Resolution.parse(resolution) if resolution is not None else None
        self._aspect_ratio_bucket_manager = aspect_ratio_bucket_manager

        self._center_crop_enabled = center_crop
        self._random_flip_enabled = random_flip

        # p=1.0: the flip decision is made once per example in __call__, not by this transform.
        self._flip_transform = transforms.RandomHorizontalFlip(p=1.0)
        self._to_tensor_transform = transforms.ToTensor()
        # Convert pixel values from range [0, 1.0] to range [-1.0, 1.0].
        # Normalize applies the following transform: out = (in - 0.5) / 0.5
        self._normalize_image_transform = transforms.Normalize([0.5], [0.5])

        self._orig_size_field_name = orig_size_field_name
        self._crop_field_name = crop_field_name

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:  # noqa: C901
        """Transform the image fields of `data` in place and record the original size and crop position.

        Returns:
            The same `data` dict, with every image field replaced by a tensor.
        """
        # This SDXL image pre-processing logic is adapted from:
        # https://github.com/huggingface/diffusers/blob/7b07f9812a58bfa96c06ed8ffe9e6b584286e2fd/examples/text_to_image/train_text_to_image_lora_sdxl.py#L850-L873

        image_fields: dict = {}
        for field_name in self._image_field_names:
            image_fields[field_name] = data[field_name]

        sizes = [image.size for image in image_fields.values()]
        # All images should have the same size.
        assert all(size == sizes[0] for size in sizes)

        # Helper function to access the first image, which is sometimes used to infer the shape of all images.
        def get_first_image():
            return next(iter(image_fields.values()))

        original_size_hw = (get_first_image().height, get_first_image().width)

        # Determine the target image resolution.
        if self._resolution is not None:
            resolution = self._resolution
        else:
            resolution = self._aspect_ratio_bucket_manager.get_aspect_ratio_bucket(Resolution.parse(original_size_hw))

        # Resize to cover the target resolution while preserving aspect ratio.
        for field_name, image in image_fields.items():
            image_fields[field_name] = resize_to_cover(image, resolution)

        # Apply cropping, and record top left crop position. The same crop box is applied to every field.
        if self._center_crop_enabled:
            top_left_y = max(0, (get_first_image().height - resolution.height) // 2)
            top_left_x = max(0, (get_first_image().width - resolution.width) // 2)
        else:
            crop_transform = transforms.RandomCrop(resolution.to_tuple())
            top_left_y, top_left_x, h, w = crop_transform.get_params(get_first_image(), resolution.to_tuple())
        for field_name, image in image_fields.items():
            image_fields[field_name] = crop(image, top_left_y, top_left_x, resolution.height, resolution.width)

        # Apply random flip and update top left crop position accordingly.
        # NOTE(review): the mirrored x-offset uses the original (pre-resize) width, not the resized width - confirm
        # this matches the intended crop-conditioning convention.
        # TODO(ryand): Use a seed for repeatable results.
        if self._random_flip_enabled and random.random() < 0.5:
            top_left_x = original_size_hw[1] - get_first_image().width - top_left_x
            for field_name, image in image_fields.items():
                image_fields[field_name] = self._flip_transform(image)

        crop_top_left_yx = (top_left_y, top_left_x)

        # Convert to Tensors.
        for field_name, image in image_fields.items():
            image_fields[field_name] = self._to_tensor_transform(image)

        # Normalize to range [-1.0, 1.0].
        # HACK(ryand): We should find a better way to determine the normalization range of each image field.
        for field_name, image in image_fields.items():
            if field_name in self._fields_to_normalize_to_range_minus_one_to_one:
                image_fields[field_name] = self._normalize_image_transform(image)

        data[self._orig_size_field_name] = original_size_hw
        data[self._crop_field_name] = crop_top_left_yx
        for field_name, image in image_fields.items():
            data[field_name] = image

        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/shuffle_caption_transform.py
================================================
import typing
import numpy as np
class ShuffleCaptionTransform:
    """Randomly reorder the delimiter-separated chunks of a caption field.

    Each chunk is stripped of surrounding whitespace, the chunks are shuffled with a seeded numpy generator, and the
    result is re-joined using the delimiter followed by a single space.

    Example:
        - Original: "unreal engine, render of sci-fi helmet, dramatic lighting"
        - Shuffled: "render of sci-fi helmet, unreal engine, dramatic lighting"
    """

    def __init__(self, field_name: str, delimiter: str = ",", seed: int = 0):
        self._field_name = field_name
        self._delimiter = delimiter
        # Dedicated generator so that the sequence of shuffles is reproducible for a given seed.
        self._rng = np.random.default_rng(seed)

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        """Shuffle `data[field_name]` in place and return the same `data` dict."""
        pieces = [piece.strip() for piece in data[self._field_name].split(self._delimiter)]
        self._rng.shuffle(pieces)
        data[self._field_name] = f"{self._delimiter} ".join(pieces)
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/template_caption_transform.py
================================================
import typing
import numpy as np
class TemplateCaptionTransform:
    """Build a caption by filling a randomly chosen template with a placeholder string.

    Each template is formatted with the placeholder as its single positional argument (e.g. "a photo of {}").
    Templates are drawn with a seeded numpy generator, so the sequence of captions is reproducible.
    """

    def __init__(self, field_name: str, placeholder_str: str, caption_templates: list[str], seed: int = 0):
        self._field_name = field_name
        self._placeholder_str = placeholder_str
        self._caption_templates = caption_templates
        self._rng = np.random.default_rng(seed)

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        """Write a freshly generated caption into `data[field_name]` and return the same `data` dict.

        Raises:
            AssertionError: If the chosen template did not insert the placeholder string into the caption.
        """
        template = self._rng.choice(self._caption_templates)
        caption = template.format(self._placeholder_str)
        # Guard against malformed templates that do not place the placeholder into the output.
        assert self._placeholder_str in caption
        data[self._field_name] = caption
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/tensor_disk_cache.py
================================================
import os
import typing
import torch
class TensorDiskCache:
    """A cache that stores dicts of `torch.Tensor`s on disk, one `<key>.pt` file per entry."""

    def __init__(self, cache_dir: str):
        super().__init__()
        self._cache_dir = cache_dir
        # Create the directory eagerly so that save() can write into it directly.
        os.makedirs(self._cache_dir, exist_ok=True)

    def _get_path(self, key: int):
        """Return the path of the file backing cache entry `key`.

        Args:
            key (int): The cache key.

        Returns:
            str: `<cache_dir>/<key>.pt`.
        """
        file_name = f"{key}.pt"
        return os.path.join(self._cache_dir, file_name)

    def save(self, key: int, data: typing.Dict[str, torch.Tensor]):
        """Write `data` to the cache under `key`.

        Args:
            key (int): The cache key.
            data (typing.Dict[str, torch.Tensor]): The data to save. It must be a dict so that cache loading code
                can be shared across users of the cache.

        Raises:
            AssertionError: If `data` is not a dict, or if an entry for `key` already exists.
        """
        assert isinstance(data, dict)
        target_file = self._get_path(key)
        assert not os.path.exists(target_file)
        torch.save(data, target_file)

    def load(self, key: int) -> typing.Dict[str, torch.Tensor]:
        """Read the cache entry stored under `key`.

        Args:
            key (int): The cache key to load.

        Returns:
            typing.Dict[str, torch.Tensor]: The data previously passed to `save()` for this key.
        """
        source_file = self._get_path(key)
        return torch.load(source_file)
================================================
FILE: src/invoke_training/_shared/data/utils/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/data/utils/aspect_ratio_bucket_manager.py
================================================
from invoke_training._shared.data.utils.resolution import Resolution
class AspectRatioBucketManager:
    """Maps image resolutions onto the closest of a fixed set of aspect-ratio bucket resolutions."""

    def __init__(self, buckets: set[Resolution]):
        # The candidate bucket resolutions.
        self.buckets = buckets

    @classmethod
    def from_constraints(
        cls, target_resolution: int, start_dim: int, end_dim: int, divisible_by: int
    ) -> "AspectRatioBucketManager":
        """Create a manager whose buckets are produced by `build_aspect_ratio_buckets()` with the same arguments."""
        buckets = cls.build_aspect_ratio_buckets(
            target_resolution=target_resolution,
            start_dim=start_dim,
            end_dim=end_dim,
            divisible_by=divisible_by,
        )
        return cls(buckets)

    @classmethod
    def build_aspect_ratio_buckets(
        cls, target_resolution: int, start_dim: int, end_dim: int, divisible_by: int
    ) -> set[Resolution]:
        """Prepare a set of aspect-ratio bucket resolutions.

        For every height from `start_dim` to `end_dim` (inclusive), in steps of `divisible_by`, the largest width that
        is a multiple of `divisible_by` and keeps `height * width <= target_resolution**2` is chosen. Both
        (height, width) and its transpose (width, height) are added to the set.

        Args:
            target_resolution (int): All resolutions in the returned set will aim to have close to
                (but <=) `target_resolution * target_resolution` pixels. Must be a multiple of `divisible_by`.
            start_dim (int): The smallest height to generate a bucket for (inclusive). Must be a multiple of
                `divisible_by`.
            end_dim (int): The largest height to generate a bucket for (inclusive). Must be >= `start_dim` and a
                multiple of `divisible_by`.
            divisible_by (int): All dimensions in the returned set of resolutions will be divisible by `divisible_by`.

        Returns:
            set[Resolution]: The aspect ratio bucket resolutions.

        Raises:
            AssertionError: If any of the divisibility / ordering constraints above is violated.
        """
        # Validate target_resolution.
        assert target_resolution % divisible_by == 0

        # Validate start_dim, end_dim.
        assert start_dim <= end_dim
        assert start_dim % divisible_by == 0
        assert end_dim % divisible_by == 0

        target_size = target_resolution * target_resolution

        buckets = set()
        height = start_dim
        while height <= end_dim:
            # Largest multiple of `divisible_by` such that height * width <= target_size.
            # NOTE(review): this is 0 if height > target_size // divisible_by; callers presumably avoid such ranges.
            width = (target_size // height) // divisible_by * divisible_by
            buckets.add(Resolution(height, width))
            buckets.add(Resolution(width, height))
            height += divisible_by

        return buckets

    def get_aspect_ratio_bucket(self, resolution: Resolution):
        """Get the bucket with the closest aspect ratio (height / width) to `resolution`.

        Raises:
            ValueError: If there are no buckets (raised by `min()` on an empty set).
        """
        # Note: If this is ever found to be a bottleneck, there is a clearly-more-efficient implementation using bisect.
        return min(self.buckets, key=lambda x: abs(x.aspect_ratio() - resolution.aspect_ratio()))
================================================
FILE: src/invoke_training/_shared/data/utils/resize.py
================================================
import math
from PIL.Image import Image
from torchvision import transforms
from invoke_training._shared.data.utils.resolution import Resolution
def resize_to_cover(image: Image, size_to_cover: Resolution) -> Image:
    """Resize `image` to the smallest size that fully covers `size_to_cover`, preserving its aspect ratio.

    The result satisfies:
        - resized_height >= size_to_cover.height
        - resized_width >= size_to_cover.width
        - resized_height == size_to_cover.height or resized_width == size_to_cover.width
        - the aspect ratio of `image` is preserved (up to rounding up to whole pixels).
    """
    target_h = size_to_cover.height
    target_w = size_to_cover.width
    height_scale = target_h / image.height
    width_scale = target_w / image.width

    if height_scale > width_scale:
        # The height needs the larger scale: match it exactly and let the width overshoot.
        new_size = (target_h, math.ceil(image.width * height_scale))
    else:
        # The width needs the larger (or equal) scale: match it exactly and let the height overshoot.
        new_size = (math.ceil(image.height * width_scale), target_w)

    resizer = transforms.Resize(new_size, interpolation=transforms.InterpolationMode.BILINEAR)
    return resizer(image)
================================================
FILE: src/invoke_training/_shared/data/utils/resolution.py
================================================
from typing import Union
class Resolution:
    """An image resolution stored as (height, width)."""

    def __init__(self, height: int, width: int):
        self.height = height
        self.width = width

    @classmethod
    def parse(cls, resolution: Union[int, tuple[int, int], "Resolution"]):
        """Initialize a Resolution object from another type.

        Args:
            resolution: An int (interpreted as a square resolution), a (height, width) tuple or 2-element list,
                or another Resolution (which is copied).

        Returns:
            Resolution: A new Resolution instance.

        Raises:
            ValueError: If `resolution` has an unsupported type, or is a sequence that does not have exactly 2 items.
        """
        if isinstance(resolution, int):
            # Assume square resolution.
            return cls(resolution, resolution)
        elif isinstance(resolution, (tuple, list)):
            # Lists are accepted too, since YAML/JSON configs produce lists rather than tuples.
            height, width = resolution
            return cls(height, width)
        elif isinstance(resolution, cls):
            return cls(resolution.height, resolution.width)
        else:
            raise ValueError(f"Unsupported resolution type: '{type(resolution)}'.")

    def aspect_ratio(self):
        """Return height / width."""
        return self.height / self.width

    def to_tuple(self) -> tuple[int, int]:
        """Return (height, width)."""
        return (self.height, self.width)

    def __eq__(self, other: object) -> bool:
        # Return NotImplemented for foreign types so that comparisons like `res == None` or `res in [None]` evaluate
        # to False instead of raising AttributeError.
        if not isinstance(other, Resolution):
            return NotImplemented
        return self.to_tuple() == other.to_tuple()

    def __lt__(self, other: "Resolution") -> bool:
        if not isinstance(other, Resolution):
            return NotImplemented
        return self.to_tuple() < other.to_tuple()

    def __hash__(self):
        return hash(self.to_tuple())

    def __repr__(self) -> str:
        return f"Resolution(height={self.height}, width={self.width})"
================================================
FILE: src/invoke_training/_shared/flux/encoding_utils.py
================================================
import logging
from typing import List, Optional, Tuple, Union
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
def get_clip_prompt_embeds(
    prompt: Union[str, List[str]],
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    device: torch.device,
    num_images_per_prompt: int = 1,
    tokenizer_max_length: int = 77,
    logger: logging.Logger | None = None,
) -> torch.FloatTensor:
    """Encode the prompt using the CLIP text encoder and return its pooled embeddings.

    Args:
        prompt: A single prompt or a list of prompts.
        tokenizer: The CLIP tokenizer.
        text_encoder: The CLIP text encoder. Its `pooler_output` is used.
        device: Device that the token ids are moved to and that the embeddings are returned on.
        num_images_per_prompt: Number of times each prompt's embedding is repeated.
        tokenizer_max_length: Prompts are padded/truncated to this many tokens.
        logger: If provided, a warning is logged when a prompt is truncated.

    Returns:
        torch.FloatTensor: Pooled embeddings with `len(prompt) * num_images_per_prompt` rows, cast to the text
            encoder's dtype.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    # Process text input with the tokenizer
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer_max_length,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    # The untruncated tokenization only serves the truncation warning, so skip it when there is no logger.
    if logger is not None:
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
            logger.warning(f"Warning: The following part of your input was truncated: {removed_text}")

    # Get prompt embeddings through the text encoder
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)

    # Use pooled output of CLIPTextModel
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # Duplicate text embeddings for each generation per prompt
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

    return prompt_embeds
def get_t5_prompt_embeds(
    prompt: Union[str, List[str]],
    tokenizer: T5TokenizerFast,
    text_encoder: T5EncoderModel,
    device: torch.device,
    num_images_per_prompt: int = 1,
    tokenizer_max_length: int = 512,
    logger: logging.Logger | None = None,
) -> torch.FloatTensor:
    """Encode the prompt using the T5 text encoder and return its per-token embeddings.

    Args:
        prompt: A single prompt or a list of prompts.
        tokenizer: The T5 tokenizer.
        text_encoder: The T5 encoder. The first element of its output is used.
        device: Device that the token ids are moved to and that the embeddings are returned on.
        num_images_per_prompt: Number of times each prompt's embeddings are repeated.
        tokenizer_max_length: Prompts are padded/truncated to this many tokens.
        logger: If provided, a warning is logged when a prompt is truncated.

    Returns:
        torch.FloatTensor: Embeddings of shape (len(prompt) * num_images_per_prompt, seq_len, embed_dim), cast to
            the text encoder's dtype.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    # Process text input with the tokenizer
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer_max_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    # The untruncated tokenization only serves the truncation warning, so skip it when there is no logger.
    if logger is not None:
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
            logger.warning(f"Warning: The following part of your input was truncated: {removed_text}")

    # Get prompt embeddings through the text encoder
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0]

    dtype = text_encoder.dtype
    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    # Get shape and duplicate for multiple generations
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds
def handle_lora_scale(
    clip_text_encoder: CLIPTextModel,
    t5_text_encoder: T5EncoderModel,
    lora_scale: Optional[float] = None,
    use_peft_backend: bool = False,
) -> bool:
    """Scale the LoRA layers of the text encoders by `lora_scale`.

    Scaling is only applied when `lora_scale` is set and `use_peft_backend` is True. Encoders that are None are
    skipped.

    Args:
        clip_text_encoder: The CLIP text encoder, or None.
        t5_text_encoder: The T5 text encoder, or None.
        lora_scale: The scale to apply to the LoRA layers.
        use_peft_backend: Whether the PEFT backend is in use.

    Returns:
        bool: True if scaling was applied (and should later be undone with `reset_lora_scale`), otherwise False.
    """
    if lora_scale is None or not use_peft_backend:
        return False

    # `scale_lora_layers` is a diffusers utility. The diffusers Flux pipeline this module was adapted from imports it
    # from `diffusers.utils`.
    from diffusers.utils import scale_lora_layers

    # Apply LoRA scaling to text encoders if they exist
    for text_encoder in (clip_text_encoder, t5_text_encoder):
        if text_encoder is not None:
            scale_lora_layers(text_encoder, lora_scale)
    return True
def reset_lora_scale(
    clip_text_encoder: CLIPTextModel,
    t5_text_encoder: T5EncoderModel,
    lora_scale: Optional[float] = None,
    lora_applied: bool = False,
    use_peft_backend: bool = False,
):
    """Undo the LoRA scaling applied by `handle_lora_scale`.

    Nothing is done unless `lora_applied` and `use_peft_backend` are both True. Encoders that are None are skipped.

    Args:
        clip_text_encoder: The CLIP text encoder, or None.
        t5_text_encoder: The T5 text encoder, or None.
        lora_scale: The scale that was previously applied.
        lora_applied: The return value of the matching `handle_lora_scale` call.
        use_peft_backend: Whether the PEFT backend is in use.
    """
    if not (lora_applied and use_peft_backend):
        return

    # `unscale_lora_layers` is a diffusers utility. The diffusers Flux pipeline this module was adapted from imports it
    # from `diffusers.utils`.
    from diffusers.utils import unscale_lora_layers

    # Reset LoRA scaling
    for text_encoder in (clip_text_encoder, t5_text_encoder):
        if text_encoder is not None:
            unscale_lora_layers(text_encoder, lora_scale)
# A lot of this code was adapted from:
# https://github.com/huggingface/diffusers/blob/ea81a4228d8ff16042c3ccaf61f0e588e60166cd/src/diffusers/pipelines/flux/pipeline_flux.py#L310-L387
def encode_prompt(
    prompt: Union[str, List[str]],
    prompt_2: Optional[Union[str, List[str]]],
    clip_tokenizer: CLIPTokenizer,
    t5_tokenizer: T5TokenizerFast,
    clip_text_encoder: CLIPTextModel,
    t5_text_encoder: T5EncoderModel,
    device: torch.device,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
    use_peft_backend: bool = False,
    clip_tokenizer_max_length: int = 77,
    t5_tokenizer_max_length: int = 512,
    logger: logging.Logger | None = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Encode the prompt using both the CLIP and T5 text encoders.

    If `prompt_embeds` is provided, no encoding is performed and `prompt_embeds` / `pooled_prompt_embeds` are
    returned as given.

    Args:
        prompt: Prompt(s) for the CLIP encoder.
        prompt_2: Prompt(s) for the T5 encoder. Falls back to `prompt` when falsy.
        clip_tokenizer: The CLIP tokenizer.
        t5_tokenizer: The T5 tokenizer.
        clip_text_encoder: The CLIP text encoder.
        t5_text_encoder: The T5 text encoder.
        device: Device for the returned tensors.
        num_images_per_prompt: Number of times each prompt's embeddings are repeated.
        prompt_embeds: Optional pre-computed T5 embeddings.
        pooled_prompt_embeds: Optional pre-computed CLIP pooled embeddings.
        lora_scale: Optional LoRA scale applied to the text encoders while encoding.
        use_peft_backend: Whether the PEFT backend is in use (required for `lora_scale` to take effect).
        clip_tokenizer_max_length: Max token length for the CLIP tokenizer.
        t5_tokenizer_max_length: Max token length for the T5 tokenizer.
        logger: Forwarded to the per-encoder helpers so that prompt truncation is reported.

    Returns:
        Tuple containing:
            - T5 text embeddings
            - CLIP pooled embeddings
            - Text IDs: zeros of shape (prompt_embeds.shape[1], 3)
    """
    # Apply LoRA scale if needed
    lora_applied = handle_lora_scale(
        clip_text_encoder=clip_text_encoder,
        t5_text_encoder=t5_text_encoder,
        lora_scale=lora_scale,
        use_peft_backend=use_peft_backend,
    )

    try:
        # If no pre-generated embeddings, create them
        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # Get CLIP pooled embeddings
            pooled_prompt_embeds = get_clip_prompt_embeds(
                prompt=prompt,
                tokenizer=clip_tokenizer,
                text_encoder=clip_text_encoder,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                tokenizer_max_length=clip_tokenizer_max_length,
                logger=logger,
            )

            # Get T5 text embeddings
            prompt_embeds = get_t5_prompt_embeds(
                prompt=prompt_2,
                tokenizer=t5_tokenizer,
                text_encoder=t5_text_encoder,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                tokenizer_max_length=t5_tokenizer_max_length,
                logger=logger,
            )
    finally:
        # Undo the LoRA scaling even if encoding fails, so the text encoders are not left in a scaled state.
        reset_lora_scale(
            clip_text_encoder=clip_text_encoder,
            t5_text_encoder=t5_text_encoder,
            lora_scale=lora_scale,
            lora_applied=lora_applied,
            use_peft_backend=use_peft_backend,
        )

    # Create text_ids placeholder for model
    dtype = clip_text_encoder.dtype if clip_text_encoder is not None else t5_text_encoder.dtype
    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

    return prompt_embeds, pooled_prompt_embeds, text_ids
================================================
FILE: src/invoke_training/_shared/flux/lora_checkpoint_utils.py
================================================
# ruff: noqa: N806
import os
from pathlib import Path
import peft
import torch
from diffusers import FluxTransformer2DModel
from transformers import CLIPTextModel
from invoke_training._shared.checkpoints.lora_checkpoint_utils import (
_convert_peft_state_dict_to_kohya_state_dict,
load_multi_model_peft_checkpoint,
save_multi_model_peft_checkpoint,
)
from invoke_training._shared.checkpoints.serialization import save_state_dict
# LoRA target module names for the Flux transformer.
# NOTE(review): "attn.to_k", "attn.to_q" and "attn.to_v" are listed in both the double-block and single-block groups,
# so they appear twice. This is presumably harmless if the consumer treats the list as a set of module-name patterns;
# confirm before relying on the list's length or order.
FLUX_TRANSFORMER_TARGET_MODULES = [
    # double blocks
    "attn.add_k_proj",
    "attn.add_q_proj",
    "attn.add_v_proj",
    "attn.to_add_out",
    "attn.to_k",
    "attn.to_q",
    "attn.to_v",
    "attn.to_out.0",
    "ff.net.0.proj",
    "ff.net.2.0",
    "ff_context.net.0.proj",
    "ff_context.net.2.0",
    # single blocks
    "attn.to_k",
    "attn.to_q",
    "attn.to_v",
    "proj_mlp",
    "proj_out",
    "proj_in",
]

# LoRA target module names for the text encoder(s).
TEXT_ENCODER_TARGET_MODULES = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"]

# Module lists copied from the diffusers training script.
# These module lists produce lighter, less expressive LoRA models than the non-light versions.
FLUX_TRANSFORMER_TARGET_MODULES_LIGHT = ["to_k", "to_q", "to_v", "to_out.0"]
FLUX_TEXT_ENCODER_TARGET_MODULES_LIGHT = ["q_proj", "k_proj", "v_proj", "out_proj"]

# Per-model keys in a multi-model PEFT checkpoint. The Kohya conversion in this module matches checkpoint
# subdirectory names against these keys.
FLUX_PEFT_TRANSFORMER_KEY = "transformer"
FLUX_PEFT_TEXT_ENCODER_1_KEY = "text_encoder_1"
FLUX_PEFT_TEXT_ENCODER_2_KEY = "text_encoder_2"

# Key prefixes used for each model's weights in a Kohya-format LoRA state dict.
FLUX_KOHYA_TRANSFORMER_KEY = "lora_unet"
FLUX_KOHYA_TEXT_ENCODER_1_KEY = "lora_clip"
FLUX_KOHYA_TEXT_ENCODER_2_KEY = "lora_t5"

# Maps each PEFT checkpoint key to the corresponding Kohya state-dict prefix.
FLUX_PEFT_TO_KOHYA_KEYS = {
    FLUX_PEFT_TRANSFORMER_KEY: FLUX_KOHYA_TRANSFORMER_KEY,
    FLUX_PEFT_TEXT_ENCODER_1_KEY: FLUX_KOHYA_TEXT_ENCODER_1_KEY,
    FLUX_PEFT_TEXT_ENCODER_2_KEY: FLUX_KOHYA_TEXT_ENCODER_2_KEY,
}
def save_flux_peft_checkpoint(
    checkpoint_dir: Path | str,
    transformer: peft.PeftModel | None,
    text_encoder_1: peft.PeftModel | None,
    text_encoder_2: peft.PeftModel | None,
):
    """Save the provided Flux PEFT models together as one multi-model PEFT checkpoint.

    Models passed as None are omitted. Each remaining model is handed to `save_multi_model_peft_checkpoint` under its
    FLUX_PEFT_* key (presumably becoming a subdirectory of `checkpoint_dir` with that name).
    """
    candidates = (
        (FLUX_PEFT_TRANSFORMER_KEY, transformer),
        (FLUX_PEFT_TEXT_ENCODER_1_KEY, text_encoder_1),
        (FLUX_PEFT_TEXT_ENCODER_2_KEY, text_encoder_2),
    )
    models = {key: model for key, model in candidates if model is not None}
    save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=models)
def load_flux_peft_checkpoint(
    checkpoint_dir: Path | str,
    transformer: FluxTransformer2DModel,
    text_encoder_1: CLIPTextModel,
    text_encoder_2: CLIPTextModel,
    is_trainable: bool = False,
):
    """Load a multi-model Flux PEFT checkpoint onto the given base models.

    Missing sub-model directories do not raise (`raise_if_subdir_missing=False`).
    NOTE(review): `text_encoder_2` is annotated as CLIPTextModel, but it is presumably the T5 encoder; confirm and
    fix the annotation.

    Returns:
        A (transformer, text_encoder_1, text_encoder_2) tuple as returned by `load_multi_model_peft_checkpoint`.
    """
    ordered_keys = (FLUX_PEFT_TRANSFORMER_KEY, FLUX_PEFT_TEXT_ENCODER_1_KEY, FLUX_PEFT_TEXT_ENCODER_2_KEY)
    base_models = dict(zip(ordered_keys, (transformer, text_encoder_1, text_encoder_2)))
    loaded = load_multi_model_peft_checkpoint(
        checkpoint_dir=checkpoint_dir,
        models=base_models,
        is_trainable=is_trainable,
        raise_if_subdir_missing=False,
    )
    return tuple(loaded[key] for key in ordered_keys)
def save_flux_kohya_checkpoint(
    checkpoint_path: Path,
    transformer: peft.PeftModel | None,
    text_encoder_1: peft.PeftModel | None,
    text_encoder_2: peft.PeftModel | None,
):
    """Save the given Flux PEFT LoRA models as a single Kohya-format state dict file.

    Models passed as None are skipped. Each provided model is written under its Kohya prefix
    (FLUX_KOHYA_TRANSFORMER_KEY, FLUX_KOHYA_TEXT_ENCODER_1_KEY, FLUX_KOHYA_TEXT_ENCODER_2_KEY).

    Args:
        checkpoint_path: Output file path. A `str` is also accepted. Parent directories are created as needed.
        transformer: The transformer PEFT model, or None.
        text_encoder_1: The first text encoder PEFT model, or None.
        text_encoder_2: The second text encoder PEFT model, or None.
    """
    checkpoint_path = Path(checkpoint_path)

    kohya_prefixes = []
    models = []
    for kohya_prefix, peft_model in (
        (FLUX_KOHYA_TRANSFORMER_KEY, transformer),
        (FLUX_KOHYA_TEXT_ENCODER_1_KEY, text_encoder_1),
        # Same Kohya prefix that FLUX_PEFT_TO_KOHYA_KEYS maps text_encoder_2 to, so that this path and
        # `convert_flux_peft_checkpoint_to_kohya_state_dict` agree instead of silently dropping this model.
        (FLUX_KOHYA_TEXT_ENCODER_2_KEY, text_encoder_2),
    ):
        if peft_model is not None:
            kohya_prefixes.append(kohya_prefix)
            models.append(peft_model)

    kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models)

    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    save_state_dict(kohya_state_dict, checkpoint_path)
def convert_flux_peft_checkpoint_to_kohya_state_dict(
    in_checkpoint_dir: Path,
    out_checkpoint_file: Path,
    dtype: torch.dtype = torch.float32,
) -> dict[str, torch.Tensor]:
    """Convert a multi-model Flux PEFT checkpoint directory to a Kohya-format LoRA file.

    Each immediate subdirectory of `in_checkpoint_dir` is treated as one PEFT model whose directory name must be a
    key of FLUX_PEFT_TO_KOHYA_KEYS.

    Args:
        in_checkpoint_dir: The PEFT checkpoint directory. A `str` is also accepted.
        out_checkpoint_file: The output file path. A `str` is also accepted. Parent directories are created as needed.
        dtype: The dtype passed to the Kohya state dict conversion.

    Returns:
        dict[str, torch.Tensor]: The Kohya-format state dict that was written to `out_checkpoint_file`.

    Raises:
        ValueError: If `in_checkpoint_dir` has no subdirectories, or has a subdirectory with an unrecognized name.
    """
    in_checkpoint_dir = Path(in_checkpoint_dir)
    out_checkpoint_file = Path(out_checkpoint_file)

    # Sort the subdirectories so that the output key order does not depend on filesystem listing order.
    peft_model_dirs = sorted(d for d in in_checkpoint_dir.iterdir() if d.is_dir())
    if len(peft_model_dirs) == 0:
        raise ValueError(f"No checkpoint files found in directory '{in_checkpoint_dir}'.")

    kohya_state_dict = {}
    for peft_model_dir in peft_model_dirs:
        kohya_prefix = FLUX_PEFT_TO_KOHYA_KEYS.get(peft_model_dir.name)
        if kohya_prefix is None:
            raise ValueError(f"Unrecognized checkpoint directory: '{peft_model_dir}'.")

        # Note: This logic to load the LoraConfig and weights directly is based on how it is done here:
        # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/src/peft/peft_model.py#L672-L689
        # This may need to be updated in the future to support other adapter types (LoKr, LoHa, etc.).
        # Also, I could see this interface breaking in the future.
        lora_config = peft.LoraConfig.from_pretrained(peft_model_dir)
        lora_weights = peft.utils.load_peft_weights(peft_model_dir, device="cpu")

        # NOTE(review): `_convert_peft_models_to_kohya_state_dict` passes transformer weights through
        # `convert_diffusers_to_flux_transformer_checkpoint` first, but this path does not. Confirm whether the
        # transformer keys written here are meant to differ from `save_flux_kohya_checkpoint` output.
        kohya_state_dict.update(
            _convert_peft_state_dict_to_kohya_state_dict(
                lora_config=lora_config, peft_state_dict=lora_weights, prefix=kohya_prefix, dtype=dtype
            )
        )

    out_checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
    save_state_dict(kohya_state_dict, out_checkpoint_file)
    return kohya_state_dict
def _convert_peft_models_to_kohya_state_dict(
    kohya_prefixes: list[str], models: list[peft.PeftModel]
) -> dict[str, torch.Tensor]:
    """Merge in-memory PEFT LoRA models into one Kohya-format state dict.

    Args:
        kohya_prefixes: The Kohya key prefix for each model. Must have the same length as `models`.
        models: PEFT models whose "default" adapter is a LoraConfig.

    Returns:
        dict[str, torch.Tensor]: The combined Kohya-format state dict (float32).

    Raises:
        ValueError: If `kohya_prefixes` and `models` differ in length (from `zip(..., strict=True)`).
        AssertionError: If a model's "default" adapter config is not a `peft.LoraConfig`.
    """
    adapter = "default"
    merged: dict[str, torch.Tensor] = {}
    for prefix, model in zip(kohya_prefixes, models, strict=True):
        config = model.peft_config[adapter]
        assert isinstance(config, peft.LoraConfig)

        weights = peft.get_peft_model_state_dict(model, adapter_name=adapter)
        # Transformer weights are remapped first (presumably from diffusers naming to Flux naming, per the helper's name).
        if prefix == FLUX_KOHYA_TRANSFORMER_KEY:
            weights = convert_diffusers_to_flux_transformer_checkpoint(weights)

        merged.update(
            _convert_peft_state_dict_to_kohya_state_dict(
                lora_config=config,
                peft_state_dict=weights,
                prefix=prefix,
                dtype=torch.float32,
            )
        )
    return merged
def find_matching_key_prefix(state_dict, key_pattern):
    """Return the base prefix of `key_pattern` if it occurs in any key of `state_dict`.

    The base prefix is `key_pattern` cut at the first ".lora_A", then at the first ".lora_B", then at the first
    ".weight".

    NOTE(review): this is a plain substring test, so e.g. "transformer_blocks.0.attn.to_q" also matches keys under
    "single_transformer_blocks.0.attn.to_q". Confirm that callers never rely on distinguishing those.

    Args:
        state_dict: The state dictionary to search in.
        key_pattern: The pattern to look for in keys.

    Returns:
        The base prefix if any key contains it, otherwise False. An empty base prefix (which is falsy) is returned
        when `key_pattern` starts with one of the cut markers and `state_dict` is non-empty.
    """
    base = key_pattern
    for marker in (".lora_A", ".lora_B", ".weight"):
        base = base.split(marker)[0]
    return base if any(base in key for key in state_dict.keys()) else False
def convert_layer_weights(target_dict, source_dict, source_pattern, target_pattern):
"""
Convert weights from source pattern to target pattern if they exist.
Args:
target_dict: Dictionary to store converted weights
source_dict: Source dictionary containing weights
source_pattern: Original key pattern to search for
target_pattern: New key pattern to use
Returns:
Tuple of (updated target_dict, updated source_dict)
"""
if original_key := find_matching_key_prefix(source_dict, source_pattern):
# Find all keys matching the pattern
keys_to_convert = [k for k in source_dict.keys() if original_key in k]
for
gitextract_5kosyax7/
├── .github/
│ └── workflows/
│ ├── deploy.yaml
│ └── test.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── docs/
│ ├── contributing/
│ │ ├── development_environment.md
│ │ ├── directory_structure.md
│ │ ├── documentation.md
│ │ └── tests.md
│ ├── get-started/
│ │ ├── installation.md
│ │ └── quick-start.md
│ ├── guides/
│ │ ├── dataset_formats.md
│ │ ├── model_merge.md
│ │ └── stable_diffusion/
│ │ ├── dpo_lora_sd.md
│ │ ├── gnome_lora_masks_sdxl.md
│ │ ├── robocats_finetune_sdxl.md
│ │ └── textual_inversion_sdxl.md
│ ├── index.md
│ ├── reference/
│ │ └── config/
│ │ ├── index.md
│ │ ├── pipelines/
│ │ │ ├── sd_lora.md
│ │ │ ├── sd_textual_inversion.md
│ │ │ ├── sdxl_finetune.md
│ │ │ ├── sdxl_lora.md
│ │ │ ├── sdxl_lora_and_textual_inversion.md
│ │ │ └── sdxl_textual_inversion.md
│ │ └── shared/
│ │ ├── data/
│ │ │ ├── data_loader_config.md
│ │ │ └── dataset_config.md
│ │ └── optimizer_config.md
│ └── templates/
│ └── python/
│ └── material/
│ └── labels.html
├── mkdocs.yml
├── pyproject.toml
├── sample_data/
│ └── bruce_the_gnome/
│ └── data.jsonl
├── src/
│ └── invoke_training/
│ ├── __init__.py
│ ├── _shared/
│ │ ├── __init__.py
│ │ ├── accelerator/
│ │ │ ├── __init__.py
│ │ │ └── accelerator_utils.py
│ │ ├── checkpoints/
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint_tracker.py
│ │ │ ├── lora_checkpoint_utils.py
│ │ │ └── serialization.py
│ │ ├── data/
│ │ │ ├── ARCHITECTURE.md
│ │ │ ├── __init__.py
│ │ │ ├── data_loaders/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── dreambooth_sd_dataloader.py
│ │ │ │ ├── image_caption_flux_dataloader.py
│ │ │ │ ├── image_caption_sd_dataloader.py
│ │ │ │ ├── image_pair_preference_sd_dataloader.py
│ │ │ │ └── textual_inversion_sd_dataloader.py
│ │ │ ├── datasets/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── build_dataset.py
│ │ │ │ ├── hf_image_caption_dataset.py
│ │ │ │ ├── hf_image_pair_preference_dataset.py
│ │ │ │ ├── image_caption_dir_dataset.py
│ │ │ │ ├── image_caption_jsonl_dataset.py
│ │ │ │ ├── image_dir_dataset.py
│ │ │ │ ├── image_pair_preference_dataset.py
│ │ │ │ └── transform_dataset.py
│ │ │ ├── samplers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── aspect_ratio_bucket_batch_sampler.py
│ │ │ │ ├── batch_offset_sampler.py
│ │ │ │ ├── concat_sampler.py
│ │ │ │ ├── interleaved_sampler.py
│ │ │ │ └── offset_sampler.py
│ │ │ ├── transforms/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── caption_prefix_transform.py
│ │ │ │ ├── concat_fields_transform.py
│ │ │ │ ├── constant_field_transform.py
│ │ │ │ ├── drop_field_transform.py
│ │ │ │ ├── flux_image_transform.py
│ │ │ │ ├── load_cache_transform.py
│ │ │ │ ├── sd_image_transform.py
│ │ │ │ ├── shuffle_caption_transform.py
│ │ │ │ ├── template_caption_transform.py
│ │ │ │ └── tensor_disk_cache.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── aspect_ratio_bucket_manager.py
│ │ │ ├── resize.py
│ │ │ └── resolution.py
│ │ ├── flux/
│ │ │ ├── encoding_utils.py
│ │ │ ├── lora_checkpoint_utils.py
│ │ │ ├── model_loading_utils.py
│ │ │ └── validation.py
│ │ ├── optimizer/
│ │ │ ├── __init__.py
│ │ │ └── optimizer_utils.py
│ │ ├── stable_diffusion/
│ │ │ ├── __init__.py
│ │ │ ├── base_model_version.py
│ │ │ ├── checkpoint_utils.py
│ │ │ ├── lora_checkpoint_utils.py
│ │ │ ├── min_snr_weighting.py
│ │ │ ├── model_loading_utils.py
│ │ │ ├── textual_inversion.py
│ │ │ ├── tokenize_captions.py
│ │ │ └── validation.py
│ │ ├── tools/
│ │ │ ├── __init__.py
│ │ │ └── generate_images.py
│ │ └── utils/
│ │ ├── import_xformers.py
│ │ └── jsonl.py
│ ├── config/
│ │ ├── __init__.py
│ │ ├── base_pipeline_config.py
│ │ ├── config_base_model.py
│ │ ├── data/
│ │ │ ├── __init__.py
│ │ │ ├── data_loader_config.py
│ │ │ └── dataset_config.py
│ │ ├── optimizer/
│ │ │ ├── __init__.py
│ │ │ └── optimizer_config.py
│ │ └── pipeline_config.py
│ ├── model_merge/
│ │ ├── __init__.py
│ │ ├── extract_lora.py
│ │ ├── merge_models.py
│ │ ├── merge_tasks_to_base.py
│ │ ├── scripts/
│ │ │ ├── extract_lora_from_model_diff.py
│ │ │ ├── merge_lora_into_model.py
│ │ │ ├── merge_models.py
│ │ │ └── merge_task_models_to_base_model.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── normalize_weights.py
│ │ └── parse_model_arg.py
│ ├── pipelines/
│ │ ├── __init__.py
│ │ ├── _experimental/
│ │ │ └── sd_dpo_lora/
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ ├── callbacks.py
│ │ ├── flux/
│ │ │ └── lora/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ ├── invoke_train.py
│ │ ├── stable_diffusion/
│ │ │ ├── __init__.py
│ │ │ ├── lora/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── config.py
│ │ │ │ └── train.py
│ │ │ └── textual_inversion/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ └── stable_diffusion_xl/
│ │ ├── __init__.py
│ │ ├── finetune/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ ├── lora/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ ├── lora_and_textual_inversion/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ └── textual_inversion/
│ │ ├── __init__.py
│ │ ├── config.py
│ │ └── train.py
│ ├── sample_configs/
│ │ ├── _experimental/
│ │ │ ├── sd_dpo_lora_pickapic_1x24gb.yaml
│ │ │ └── sd_dpo_lora_refinement_pokemon_1x24gb.yaml
│ │ ├── flux_lora_1x40gb.yaml
│ │ ├── sd_lora_baroque_1x8gb.yaml
│ │ ├── sd_textual_inversion_gnome_1x8gb.yaml
│ │ ├── sdxl_finetune_baroque_1x24gb.yaml
│ │ ├── sdxl_finetune_robocats_1x24gb.yaml
│ │ ├── sdxl_lora_and_ti_gnome_1x24gb.yaml
│ │ ├── sdxl_lora_baroque_1x24gb.yaml
│ │ ├── sdxl_lora_baroque_1x8gb.yaml
│ │ ├── sdxl_lora_masks_gnome_1x24gb.yaml
│ │ ├── sdxl_textual_inversion_gnome_1x24gb.yaml
│ │ └── sdxl_textual_inversion_masks_gnome_1x24gb.yaml
│ ├── scripts/
│ │ ├── __init__.py
│ │ ├── _experimental/
│ │ │ ├── auto_caption/
│ │ │ │ └── auto_caption_images.py
│ │ │ ├── masks/
│ │ │ │ ├── clipseg.py
│ │ │ │ ├── generate_masks.py
│ │ │ │ └── generate_masks_for_jsonl_dataset.py
│ │ │ └── rank_images.py
│ │ ├── convert_sd_lora_to_kohya_format.py
│ │ ├── invoke_generate_images.py
│ │ ├── invoke_train.py
│ │ ├── invoke_train_ui.py
│ │ ├── invoke_visualize_data_loading.py
│ │ └── utils/
│ │ └── image_dir_dataset.py
│ └── ui/
│ ├── __init__.py
│ ├── app.py
│ ├── config_groups/
│ │ ├── __init__.py
│ │ ├── aspect_ratio_bucket_config_group.py
│ │ ├── base_pipeline_config_group.py
│ │ ├── dataset_config_group.py
│ │ ├── flux_lora_config_group.py
│ │ ├── image_caption_sd_data_loader_config_group.py
│ │ ├── optimizer_config_group.py
│ │ ├── sd_lora_config_group.py
│ │ ├── sd_textual_inversion_config_group.py
│ │ ├── sdxl_finetune_config_group.py
│ │ ├── sdxl_lora_and_textual_inversion_config_group.py
│ │ ├── sdxl_lora_config_group.py
│ │ ├── sdxl_textual_inversion_config_group.py
│ │ ├── textual_inversion_sd_data_loader_config_group.py
│ │ └── ui_config_element.py
│ ├── gradio_blocks/
│ │ ├── header.py
│ │ └── pipeline_tab.py
│ ├── index.html
│ ├── pages/
│ │ ├── data_page.py
│ │ └── training_page.py
│ └── utils/
│ ├── prompts.py
│ └── utils.py
└── tests/
└── invoke_training/
├── _shared/
│ ├── __init__.py
│ ├── checkpoints/
│ │ ├── test_checkpoint_tracker.py
│ │ └── test_serialization.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── data_loaders/
│ │ │ ├── __init__.py
│ │ │ ├── test_dreambooth_sd_dataloader.py
│ │ │ ├── test_image_caption_sd_dataloader.py
│ │ │ ├── test_image_pair_preference_sd_dataloader.py
│ │ │ └── test_textual_inversion_sd_dataloader.py
│ │ ├── dataset_fixtures.py
│ │ ├── datasets/
│ │ │ ├── __init__.py
│ │ │ ├── test_hf_image_caption_dataset.py
│ │ │ ├── test_hf_image_pair_preference_dataset.py
│ │ │ ├── test_image_caption_dir_dataset.py
│ │ │ ├── test_image_caption_jsonl_dataset.py
│ │ │ ├── test_image_dir_dataset.py
│ │ │ ├── test_image_pair_preference_dataset.py
│ │ │ └── test_transform_dataset.py
│ │ ├── samplers/
│ │ │ ├── __init__.py
│ │ │ ├── test_aspect_ratio_bucket_batch_sampler.py
│ │ │ ├── test_batch_offset_sampler.py
│ │ │ ├── test_concat_sampler.py
│ │ │ ├── test_interleaved_sampler.py
│ │ │ └── test_offset_sampler.py
│ │ ├── transforms/
│ │ │ ├── __init__.py
│ │ │ ├── test_caption_prefix_transform.py
│ │ │ ├── test_concat_fields_transform.py
│ │ │ ├── test_constant_field_transform.py
│ │ │ ├── test_drop_field_transform.py
│ │ │ ├── test_load_cache_transform.py
│ │ │ ├── test_sd_image_transform.py
│ │ │ ├── test_shuffle_caption_transform.py
│ │ │ ├── test_template_caption_transform.py
│ │ │ └── test_tensor_disk_cache.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── test_aspect_ratio_bucket_manager.py
│ │ ├── test_resize.py
│ │ └── test_resolution.py
│ ├── stable_diffusion/
│ │ ├── __init__.py
│ │ ├── test_base_model_version.py
│ │ ├── test_lora_checkpoint_utils.py
│ │ ├── test_model_loading_utils.py
│ │ ├── test_textual_inversion.py
│ │ └── ti_embedding_checkpoint_fixture.py
│ └── utils/
│ └── test_jsonl.py
├── config/
│ └── pipelines/
│ └── test_pipeline_config.py
├── model_merge/
│ ├── __init__.py
│ ├── test_merge_models.py
│ ├── test_merge_tasks_to_base.py
│ └── utils.py
└── ui/
└── utils/
└── test_prompts.py
SYMBOL INDEX (581 symbols across 157 files)
FILE: src/invoke_training/_shared/accelerator/accelerator_utils.py
function initialize_accelerator (line 14) | def initialize_accelerator(
function initialize_logging (line 40) | def initialize_logging(logger_name: str, accelerator: Accelerator) -> Mu...
function get_mixed_precision_dtype (line 71) | def get_mixed_precision_dtype(accelerator: Accelerator):
function get_dtype_from_str (line 95) | def get_dtype_from_str(dtype_str: Literal["float16", "bfloat16", "float3...
FILE: src/invoke_training/_shared/checkpoints/checkpoint_tracker.py
class CheckpointTracker (line 6) | class CheckpointTracker:
method __init__ (line 14) | def __init__(
method prune (line 48) | def prune(self, buffer_num: int = 1) -> int:
method get_path (line 83) | def get_path(self, epoch: int, step: int) -> str:
FILE: src/invoke_training/_shared/checkpoints/lora_checkpoint_utils.py
function save_multi_model_peft_checkpoint (line 7) | def save_multi_model_peft_checkpoint(checkpoint_dir: Path | str, models:...
function load_multi_model_peft_checkpoint (line 30) | def load_multi_model_peft_checkpoint(
function _convert_peft_state_dict_to_kohya_state_dict (line 57) | def _convert_peft_state_dict_to_kohya_state_dict(
function _convert_peft_models_to_kohya_state_dict (line 79) | def _convert_peft_models_to_kohya_state_dict(
FILE: src/invoke_training/_shared/checkpoints/serialization.py
function save_state_dict (line 8) | def save_state_dict(state_dict: typing.Dict[str, torch.Tensor], out_file...
function load_state_dict (line 33) | def load_state_dict(in_file: typing.Union[Path, str]) -> typing.Dict[str...
FILE: src/invoke_training/_shared/data/data_loaders/dreambooth_sd_dataloader.py
function build_dreambooth_sd_dataloader (line 25) | def build_dreambooth_sd_dataloader(
FILE: src/invoke_training/_shared/data/data_loaders/image_caption_flux_dataloader.py
function build_image_caption_flux_dataloader (line 33) | def build_image_caption_flux_dataloader( # noqa: C901
FILE: src/invoke_training/_shared/data/data_loaders/image_caption_sd_dataloader.py
function sd_image_caption_collate_fn (line 27) | def sd_image_caption_collate_fn(examples):
function build_aspect_ratio_bucket_manager (line 64) | def build_aspect_ratio_bucket_manager(config: AspectRatioBucketConfig):
function build_image_caption_sd_dataloader (line 73) | def build_image_caption_sd_dataloader( # noqa: C901
FILE: src/invoke_training/_shared/data/data_loaders/image_pair_preference_sd_dataloader.py
function sd_image_pair_preference_collate_fn (line 15) | def sd_image_pair_preference_collate_fn(examples):
function build_image_pair_preference_sd_dataloader (line 49) | def build_image_pair_preference_sd_dataloader(
FILE: src/invoke_training/_shared/data/data_loaders/textual_inversion_sd_dataloader.py
function get_preset_ti_caption_templates (line 33) | def get_preset_ti_caption_templates(preset: Literal["object", "style"]) ...
function build_textual_inversion_sd_dataloader (line 97) | def build_textual_inversion_sd_dataloader( # noqa: C901
FILE: src/invoke_training/_shared/data/datasets/build_dataset.py
function build_hf_hub_image_caption_dataset (line 15) | def build_hf_hub_image_caption_dataset(config: HFHubImageCaptionDatasetC...
function build_image_caption_jsonl_dataset (line 27) | def build_image_caption_jsonl_dataset(config: ImageCaptionJsonlDatasetCo...
function build_image_caption_dir_dataset (line 36) | def build_image_caption_dir_dataset(config: ImageCaptionDirDatasetConfig...
function build_hf_image_pair_preference_dataset (line 43) | def build_hf_image_pair_preference_dataset(
FILE: src/invoke_training/_shared/data/datasets/hf_image_caption_dataset.py
class HFImageCaptionDataset (line 11) | class HFImageCaptionDataset(torch.utils.data.Dataset):
method __init__ (line 18) | def __init__(self, hf_dataset, image_column: str = "image", caption_co...
method from_dir (line 42) | def from_dir(
method from_hub (line 67) | def from_hub(
method get_image_dimensions (line 87) | def get_image_dimensions(self) -> list[Resolution]:
method __len__ (line 101) | def __len__(self) -> int:
method __getitem__ (line 109) | def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
FILE: src/invoke_training/_shared/data/datasets/hf_image_pair_preference_dataset.py
class HFImagePairPreferenceDataset (line 9) | class HFImagePairPreferenceDataset(torch.utils.data.Dataset):
method __init__ (line 16) | def __init__(
method from_hub (line 77) | def from_hub(
method __len__ (line 101) | def __len__(self) -> int:
method __getitem__ (line 109) | def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
FILE: src/invoke_training/_shared/data/datasets/image_caption_dir_dataset.py
class ImageCaptionDirDataset (line 10) | class ImageCaptionDirDataset(torch.utils.data.Dataset):
method __init__ (line 13) | def __init__(
method _load_image (line 64) | def _load_image(self, image_path: str) -> Image.Image:
method get_image_dimensions (line 69) | def get_image_dimensions(self) -> list[Resolution]:
method __len__ (line 83) | def __len__(self) -> int:
method __getitem__ (line 86) | def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
FILE: src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py
class ImageCaptionExample (line 16) | class ImageCaptionExample(BaseModel):
class ImageCaptionJsonlDataset (line 22) | class ImageCaptionJsonlDataset(torch.utils.data.Dataset):
method __init__ (line 25) | def __init__(
method save_jsonl (line 55) | def save_jsonl(self):
method _get_image_path (line 67) | def _get_image_path(self, idx: int) -> str:
method _get_mask_path (line 77) | def _get_mask_path(self, idx: int) -> str:
method _load_image (line 87) | def _load_image(self, image_path: str) -> Image.Image:
method _load_mask (line 92) | def _load_mask(self, mask_path: str) -> Image.Image:
method _load_example (line 95) | def _load_example(self, idx: int) -> dict[str, typing.Any]:
method get_image_dimensions (line 105) | def get_image_dimensions(self) -> list[Resolution]:
method __len__ (line 118) | def __len__(self) -> int:
method __getitem__ (line 121) | def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
FILE: src/invoke_training/_shared/data/datasets/image_dir_dataset.py
class ImageDirDataset (line 10) | class ImageDirDataset(torch.utils.data.Dataset):
method __init__ (line 13) | def __init__(
method _load_image (line 49) | def _load_image(self, image_path: str) -> Image.Image:
method get_image_dimensions (line 54) | def get_image_dimensions(self) -> list[Resolution]:
method __len__ (line 68) | def __len__(self) -> int:
method __getitem__ (line 71) | def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
FILE: src/invoke_training/_shared/data/datasets/image_pair_preference_dataset.py
class ImagePairPreferenceDataset (line 11) | class ImagePairPreferenceDataset(torch.utils.data.Dataset):
method __init__ (line 12) | def __init__(self, dataset_dir: str):
method save_metadata (line 19) | def save_metadata(
method __len__ (line 27) | def __len__(self) -> int:
method __getitem__ (line 30) | def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
FILE: src/invoke_training/_shared/data/datasets/transform_dataset.py
class TransformDataset (line 11) | class TransformDataset(torch.utils.data.Dataset):
method __init__ (line 14) | def __init__(self, base_dataset: torch.utils.data.Dataset, transforms:...
method __len__ (line 19) | def __len__(self) -> int:
method __getitem__ (line 22) | def __getitem__(self, idx: int) -> DataType:
FILE: src/invoke_training/_shared/data/samplers/aspect_ratio_bucket_batch_sampler.py
class AspectRatioBucketBatchSampler (line 15) | class AspectRatioBucketBatchSampler(Sampler[list[int]]):
method __init__ (line 18) | def __init__(
method __str__ (line 34) | def __str__(self) -> str:
method from_image_sizes (line 44) | def from_image_sizes(
method _build_bucket_to_index_map (line 57) | def _build_bucket_to_index_map(
method get_buckets (line 73) | def get_buckets(self) -> AspectRatioBuckets:
method __iter__ (line 76) | def __iter__(self) -> Iterator[list[int]]:
method __len__ (line 103) | def __len__(self) -> int:
function log_aspect_ratio_buckets (line 110) | def log_aspect_ratio_buckets(logger: logging.Logger, batch_sampler: Aspe...
FILE: src/invoke_training/_shared/data/samplers/batch_offset_sampler.py
class BatchOffsetSampler (line 6) | class BatchOffsetSampler(Sampler[int]):
method __init__ (line 9) | def __init__(self, sampler: Sampler[int], offset: int):
method __iter__ (line 13) | def __iter__(self) -> typing.Iterator[int]:
method __len__ (line 18) | def __len__(self) -> int:
FILE: src/invoke_training/_shared/data/samplers/concat_sampler.py
class ConcatSampler (line 9) | class ConcatSampler(Sampler[T_co]):
method __init__ (line 19) | def __init__(self, samplers: list[Sampler[T_co] | typing.Iterable[T_co...
method __iter__ (line 22) | def __iter__(self) -> typing.Iterator[T_co]:
method __len__ (line 25) | def __len__(self) -> int:
FILE: src/invoke_training/_shared/data/samplers/interleaved_sampler.py
class InterleavedSampler (line 8) | class InterleavedSampler(Sampler[T_co]):
method __init__ (line 21) | def __init__(self, samplers: list[Sampler[T_co] | typing.Iterable[T_co...
method __iter__ (line 25) | def __iter__(self) -> typing.Iterator[T_co]:
method __len__ (line 38) | def __len__(self) -> int:
FILE: src/invoke_training/_shared/data/samplers/offset_sampler.py
class OffsetSampler (line 6) | class OffsetSampler(Sampler[int]):
method __init__ (line 9) | def __init__(self, sampler: Sampler[int], offset: int):
method __iter__ (line 13) | def __iter__(self) -> typing.Iterator[int]:
method __len__ (line 17) | def __len__(self) -> int:
FILE: src/invoke_training/_shared/data/transforms/caption_prefix_transform.py
class CaptionPrefixTransform (line 4) | class CaptionPrefixTransform:
method __init__ (line 7) | def __init__(self, caption_field_name: str, prefix: str):
method __call__ (line 11) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...
FILE: src/invoke_training/_shared/data/transforms/concat_fields_transform.py
class ConcatFieldsTransform (line 4) | class ConcatFieldsTransform:
method __init__ (line 7) | def __init__(self, src_field_names: list[str], dst_field_name: str, se...
method __call__ (line 12) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...
FILE: src/invoke_training/_shared/data/transforms/constant_field_transform.py
class ConstantFieldTransform (line 4) | class ConstantFieldTransform:
method __init__ (line 7) | def __init__(self, field_name: str, field_value: typing.Any):
method __call__ (line 11) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...
FILE: src/invoke_training/_shared/data/transforms/drop_field_transform.py
class DropFieldTransform (line 4) | class DropFieldTransform:
method __init__ (line 7) | def __init__(self, field_to_drop: str):
method __call__ (line 10) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...
FILE: src/invoke_training/_shared/data/transforms/flux_image_transform.py
class FluxImageTransform (line 10) | class FluxImageTransform:
method __init__ (line 13) | def __init__(
method __call__ (line 42) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...
FILE: src/invoke_training/_shared/data/transforms/load_cache_transform.py
class LoadCacheTransform (line 6) | class LoadCacheTransform:
method __init__ (line 9) | def __init__(
method __call__ (line 24) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...
FILE: src/invoke_training/_shared/data/transforms/sd_image_transform.py
class SDImageTransform (line 11) | class SDImageTransform:
method __init__ (line 14) | def __init__(
method __call__ (line 59) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...
FILE: src/invoke_training/_shared/data/transforms/shuffle_caption_transform.py
class ShuffleCaptionTransform (line 6) | class ShuffleCaptionTransform:
method __init__ (line 14) | def __init__(self, field_name: str, delimiter: str = ",", seed: int = 0):
method __call__ (line 19) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...
FILE: src/invoke_training/_shared/data/transforms/template_caption_transform.py
class TemplateCaptionTransform (line 6) | class TemplateCaptionTransform:
method __init__ (line 11) | def __init__(self, field_name: str, placeholder_str: str, caption_temp...
method __call__ (line 17) | def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[...
FILE: src/invoke_training/_shared/data/transforms/tensor_disk_cache.py
class TensorDiskCache (line 7) | class TensorDiskCache:
method __init__ (line 10) | def __init__(self, cache_dir: str):
method _get_path (line 16) | def _get_path(self, key: int):
method save (line 25) | def save(self, key: int, data: typing.Dict[str, torch.Tensor]):
method load (line 41) | def load(self, key: int) -> typing.Dict[str, torch.Tensor]:
FILE: src/invoke_training/_shared/data/utils/aspect_ratio_bucket_manager.py
class AspectRatioBucketManager (line 4) | class AspectRatioBucketManager:
method __init__ (line 5) | def __init__(self, buckets: set[Resolution]):
method from_constraints (line 9) | def from_constraints(cls, target_resolution: int, start_dim: int, end_...
method build_aspect_ratio_buckets (line 19) | def build_aspect_ratio_buckets(
method get_aspect_ratio_bucket (line 56) | def get_aspect_ratio_bucket(self, resolution: Resolution):
FILE: src/invoke_training/_shared/data/utils/resize.py
function resize_to_cover (line 9) | def resize_to_cover(image: Image, size_to_cover: Resolution) -> Image:
FILE: src/invoke_training/_shared/data/utils/resolution.py
class Resolution (line 4) | class Resolution:
method __init__ (line 5) | def __init__(self, height: int, width: int):
method parse (line 10) | def parse(cls, resolution: Union[int, tuple[int, int], "Resolution"]):
method aspect_ratio (line 23) | def aspect_ratio(self):
method to_tuple (line 26) | def to_tuple(self) -> tuple[int, int]:
method __eq__ (line 29) | def __eq__(self, other: "Resolution") -> bool:
method __lt__ (line 32) | def __lt__(self, other: "Resolution") -> bool:
method __hash__ (line 35) | def __hash__(self):
FILE: src/invoke_training/_shared/flux/encoding_utils.py
function get_clip_prompt_embeds (line 8) | def get_clip_prompt_embeds(
function get_t5_prompt_embeds (line 55) | def get_t5_prompt_embeds(
function handle_lora_scale (line 101) | def handle_lora_scale(
function reset_lora_scale (line 121) | def reset_lora_scale(
function encode_prompt (line 141) | def encode_prompt(
FILE: src/invoke_training/_shared/flux/lora_checkpoint_utils.py
function save_flux_peft_checkpoint (line 62) | def save_flux_peft_checkpoint(
function load_flux_peft_checkpoint (line 79) | def load_flux_peft_checkpoint(
function save_flux_kohya_checkpoint (line 100) | def save_flux_kohya_checkpoint(
function convert_flux_peft_checkpoint_to_kohya_state_dict (line 121) | def convert_flux_peft_checkpoint_to_kohya_state_dict(
function _convert_peft_models_to_kohya_state_dict (line 158) | def _convert_peft_models_to_kohya_state_dict(
function find_matching_key_prefix (line 185) | def find_matching_key_prefix(state_dict, key_pattern):
function convert_layer_weights (line 204) | def convert_layer_weights(target_dict, source_dict, source_pattern, targ...
function convert_double_transformer_block (line 231) | def convert_double_transformer_block(target_dict, source_dict, prefix=""...
function convert_single_transformer_block (line 333) | def convert_single_transformer_block(target_dict, source_dict, prefix, b...
function convert_embedding_layers (line 384) | def convert_embedding_layers(target_dict, source_dict, prefix, has_guida...
function convert_output_layers (line 436) | def convert_output_layers(target_dict, source_dict, prefix):
function convert_diffusers_to_flux_transformer_checkpoint (line 460) | def convert_diffusers_to_flux_transformer_checkpoint(
FILE: src/invoke_training/_shared/flux/model_loading_utils.py
class PipelineVersionEnum (line 9) | class PipelineVersionEnum(Enum):
function load_pipeline (line 13) | def load_pipeline(
function load_models_flux (line 65) | def load_models_flux(
FILE: src/invoke_training/_shared/flux/validation.py
function generate_validation_images_flux (line 25) | def generate_validation_images_flux( # noqa: C901
FILE: src/invoke_training/_shared/optimizer/optimizer_utils.py
function initialize_optimizer (line 7) | def initialize_optimizer(
FILE: src/invoke_training/_shared/stable_diffusion/base_model_version.py
class BaseModelVersionEnum (line 6) | class BaseModelVersionEnum(Enum):
function get_base_model_version (line 13) | def get_base_model_version(
function check_base_model_version (line 54) | def check_base_model_version(
FILE: src/invoke_training/_shared/stable_diffusion/checkpoint_utils.py
function save_sdxl_diffusers_unet_checkpoint (line 8) | def save_sdxl_diffusers_unet_checkpoint(
function save_sdxl_diffusers_checkpoint (line 25) | def save_sdxl_diffusers_checkpoint(
FILE: src/invoke_training/_shared/stable_diffusion/lora_checkpoint_utils.py
function save_sd_peft_checkpoint (line 66) | def save_sd_peft_checkpoint(
function load_sd_peft_checkpoint (line 78) | def load_sd_peft_checkpoint(
function save_sdxl_peft_checkpoint (line 91) | def save_sdxl_peft_checkpoint(
function load_sdxl_peft_checkpoint (line 108) | def load_sdxl_peft_checkpoint(
function save_sd_kohya_checkpoint (line 129) | def save_sd_kohya_checkpoint(checkpoint_path: Path, unet: peft.PeftModel...
function save_sdxl_kohya_checkpoint (line 143) | def save_sdxl_kohya_checkpoint(
function convert_sd_peft_checkpoint_to_kohya_state_dict (line 165) | def convert_sd_peft_checkpoint_to_kohya_state_dict(
FILE: src/invoke_training/_shared/stable_diffusion/min_snr_weighting.py
function compute_snr (line 5) | def compute_snr(noise_scheduler: DDPMScheduler, timesteps: torch.Tensor):
FILE: src/invoke_training/_shared/stable_diffusion/model_loading_utils.py
class PipelineVersionEnum (line 21) | class PipelineVersionEnum(Enum):
function load_pipeline (line 26) | def load_pipeline(
function from_pretrained_with_variant_fallback (line 75) | def from_pretrained_with_variant_fallback(
function load_models_sd (line 113) | def load_models_sd(
function load_models_sdxl (line 167) | def load_models_sdxl(
FILE: src/invoke_training/_shared/stable_diffusion/textual_inversion.py
function _expand_placeholder_token (line 10) | def _expand_placeholder_token(placeholder_token: str, num_vectors: int =...
function _add_tokens_to_tokenizer (line 23) | def _add_tokens_to_tokenizer(placeholder_tokens: list[str], tokenizer: P...
function expand_placeholders_in_caption (line 37) | def expand_placeholders_in_caption(caption: str, tokenizer: CLIPTokenize...
function initialize_placeholder_tokens_from_initializer_token (line 68) | def initialize_placeholder_tokens_from_initializer_token(
function initialize_placeholder_tokens_from_initial_phrase (line 106) | def initialize_placeholder_tokens_from_initial_phrase(
function initialize_placeholder_tokens_from_initial_embedding (line 131) | def initialize_placeholder_tokens_from_initial_embedding(
function restore_original_embeddings (line 172) | def restore_original_embeddings(
FILE: src/invoke_training/_shared/stable_diffusion/tokenize_captions.py
function tokenize_captions (line 7) | def tokenize_captions(tokenizer: CLIPTokenizer, captions: list[str]) -> ...
FILE: src/invoke_training/_shared/stable_diffusion/validation.py
function generate_validation_images_sd (line 24) | def generate_validation_images_sd( # noqa: C901
function generate_validation_images_sdxl (line 140) | def generate_validation_images_sdxl( # noqa: C901
FILE: src/invoke_training/_shared/tools/generate_images.py
function generate_images (line 15) | def generate_images(
FILE: src/invoke_training/_shared/utils/import_xformers.py
function import_xformers (line 1) | def import_xformers():
FILE: src/invoke_training/_shared/utils/jsonl.py
function load_jsonl (line 6) | def load_jsonl(jsonl_path: Path | str) -> list[Any]:
function save_jsonl (line 15) | def save_jsonl(data: list[Any], jsonl_path: Path | str) -> None:
FILE: src/invoke_training/config/base_pipeline_config.py
class BasePipelineConfig (line 7) | class BasePipelineConfig(ConfigBaseModel):
FILE: src/invoke_training/config/config_base_model.py
class ConfigBaseModel (line 4) | class ConfigBaseModel(BaseModel):
FILE: src/invoke_training/config/data/data_loader_config.py
class AspectRatioBucketConfig (line 10) | class AspectRatioBucketConfig(ConfigBaseModel):
class ImageCaptionSDDataLoaderConfig (line 42) | class ImageCaptionSDDataLoaderConfig(ConfigBaseModel):
class ImageCaptionFluxDataLoaderConfig (line 73) | class ImageCaptionFluxDataLoaderConfig(ConfigBaseModel):
class DreamboothSDDataLoaderConfig (line 104) | class DreamboothSDDataLoaderConfig(ConfigBaseModel):
class TextualInversionSDDataLoaderConfig (line 142) | class TextualInversionSDDataLoaderConfig(ConfigBaseModel):
FILE: src/invoke_training/config/data/dataset_config.py
class HFHubImageCaptionDatasetConfig (line 8) | class HFHubImageCaptionDatasetConfig(ConfigBaseModel):
class ImageCaptionJsonlDatasetConfig (line 33) | class ImageCaptionJsonlDatasetConfig(ConfigBaseModel):
class ImageDirDatasetConfig (line 54) | class ImageDirDatasetConfig(ConfigBaseModel):
class ImageCaptionDirDatasetConfig (line 67) | class ImageCaptionDirDatasetConfig(ConfigBaseModel):
FILE: src/invoke_training/config/optimizer/optimizer_config.py
class AdamOptimizerConfig (line 6) | class AdamOptimizerConfig(ConfigBaseModel):
class ProdigyOptimizerConfig (line 26) | class ProdigyOptimizerConfig(ConfigBaseModel):
FILE: src/invoke_training/model_merge/extract_lora.py
function get_patched_base_weights_from_peft_model (line 10) | def get_patched_base_weights_from_peft_model(peft_model: PeftModel) -> d...
function get_state_dict_diff (line 29) | def get_state_dict_diff(
function extract_lora_from_diffs (line 37) | def extract_lora_from_diffs(
FILE: src/invoke_training/model_merge/merge_models.py
function merge_models (line 10) | def merge_models(
function lerp (line 53) | def lerp(a: torch.Tensor, b: torch.Tensor, weight_a: float) -> torch.Ten...
function slerp (line 58) | def slerp(a: torch.Tensor, b: torch.Tensor, weight_a: float, dot_product...
FILE: src/invoke_training/model_merge/merge_tasks_to_base.py
function merge_tasks_to_base_model (line 9) | def merge_tasks_to_base_model(
FILE: src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py
class StableDiffusionModel (line 38) | class StableDiffusionModel:
method all_none (line 45) | def all_none(self) -> bool:
function load_model (line 49) | def load_model(
function str_to_device (line 111) | def str_to_device(device_str: Literal["cuda", "cpu"]) -> torch.device:
function state_dict_to_device (line 120) | def state_dict_to_device(state_dict: dict[str, torch.Tensor], device: to...
function extract_lora_from_submodel (line 124) | def extract_lora_from_submodel(
function extract_lora (line 187) | def extract_lora(
function main (line 257) | def main():
FILE: src/invoke_training/model_merge/scripts/merge_lora_into_model.py
function to_invokeai_base_model_type (line 24) | def to_invokeai_base_model_type(model_type: PipelineVersionEnum):
function merge_lora_into_sd_model (line 34) | def merge_lora_into_sd_model(
function parse_lora_model_arg (line 102) | def parse_lora_model_arg(lora_model_arg: str) -> tuple[str, float]:
function main (line 113) | def main():
FILE: src/invoke_training/model_merge/scripts/merge_models.py
class MergeModel (line 16) | class MergeModel:
function run_merge_models (line 22) | def run_merge_models(
function parse_model_args (line 73) | def parse_model_args(models: list[str], weights: list[str]) -> list[Merg...
function main (line 85) | def main():
FILE: src/invoke_training/model_merge/scripts/merge_task_models_to_base_model.py
function run_merge_models (line 14) | def run_merge_models(
function main (line 91) | def main():
FILE: src/invoke_training/model_merge/utils/normalize_weights.py
function normalize_weights (line 1) | def normalize_weights(weights: list[float]) -> list[float]:
FILE: src/invoke_training/model_merge/utils/parse_model_arg.py
function parse_model_arg (line 1) | def parse_model_arg(model: str, delimiter: str = "::") -> tuple[str, str...
FILE: src/invoke_training/pipelines/_experimental/sd_dpo_lora/config.py
class HFHubImagePairPreferenceDatasetConfig (line 10) | class HFHubImagePairPreferenceDatasetConfig(ConfigBaseModel):
class ImagePairPreferenceDatasetConfig (line 16) | class ImagePairPreferenceDatasetConfig(ConfigBaseModel):
class ImagePairPreferenceSDDataLoaderConfig (line 23) | class ImagePairPreferenceSDDataLoaderConfig(ConfigBaseModel):
class SdDirectPreferenceOptimizationLoraConfig (line 50) | class SdDirectPreferenceOptimizationLoraConfig(BasePipelineConfig):
method check_validation_prompts (line 242) | def check_validation_prompts(self):
FILE: src/invoke_training/pipelines/_experimental/sd_dpo_lora/train.py
function _save_sd_lora_checkpoint (line 47) | def _save_sd_lora_checkpoint(
function train_forward_dpo (line 70) | def train_forward_dpo( # noqa: C901
function train (line 191) | def train(config: SdDirectPreferenceOptimizationLoraConfig, callbacks: l...
FILE: src/invoke_training/pipelines/callbacks.py
class ModelType (line 5) | class ModelType(Enum):
class ModelCheckpoint (line 37) | class ModelCheckpoint:
method __init__ (line 40) | def __init__(self, file_path: str, model_type: ModelType):
class TrainingCheckpoint (line 45) | class TrainingCheckpoint:
method __init__ (line 50) | def __init__(self, models: list[ModelCheckpoint], epoch: int, step: int):
class ValidationImage (line 56) | class ValidationImage:
method __init__ (line 57) | def __init__(self, file_path: str, prompt: str, image_idx: int):
class ValidationImages (line 71) | class ValidationImages:
method __init__ (line 72) | def __init__(self, images: list[ValidationImage], epoch: int, step: int):
class PipelineCallbacks (line 85) | class PipelineCallbacks(ABC):
method on_save_checkpoint (line 86) | def on_save_checkpoint(self, checkpoint: TrainingCheckpoint):
method on_save_validation_images (line 89) | def on_save_validation_images(self, images: ValidationImages):
FILE: src/invoke_training/pipelines/flux/lora/config.py
class FluxLoraConfig (line 17) | class FluxLoraConfig(BasePipelineConfig):
FILE: src/invoke_training/pipelines/flux/lora/train.py
function _save_flux_lora_checkpoint (line 47) | def _save_flux_lora_checkpoint(
function _build_data_loader (line 86) | def _build_data_loader(
function cache_text_encoder_outputs (line 109) | def cache_text_encoder_outputs(
function cache_vae_outputs (line 137) | def cache_vae_outputs(cache_dir: str, data_loader: DataLoader, vae: Auto...
function get_sigmas (line 156) | def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch....
function get_noisy_latents (line 168) | def get_noisy_latents(noise_scheduler: FlowMatchEulerDiscreteScheduler, ...
function decode_latents (line 209) | def decode_latents(vae: AutoencoderKL, latents: torch.Tensor):
function train_forward (line 222) | def train_forward( # noqa: C901
function train (line 304) | def train(config: FluxLoraConfig, callbacks: list[PipelineCallbacks] | N...
FILE: src/invoke_training/pipelines/invoke_train.py
function train (line 17) | def train(config: PipelineConfig, callbacks: list[PipelineCallbacks] | N...
FILE: src/invoke_training/pipelines/stable_diffusion/lora/config.py
class SdLoraConfig (line 14) | class SdLoraConfig(BasePipelineConfig):
method check_validation_prompts (line 224) | def check_validation_prompts(self):
FILE: src/invoke_training/pipelines/stable_diffusion/lora/train.py
function _save_sd_lora_checkpoint (line 46) | def _save_sd_lora_checkpoint(
function _build_data_loader (line 80) | def _build_data_loader(
function cache_text_encoder_outputs (line 115) | def cache_text_encoder_outputs(
function cache_vae_outputs (line 143) | def cache_vae_outputs(cache_dir: str, data_loader: DataLoader, vae: Auto...
function train_forward (line 162) | def train_forward( # noqa: C901
function train (line 267) | def train(config: SdLoraConfig, callbacks: list[PipelineCallbacks] | Non...
FILE: src/invoke_training/pipelines/stable_diffusion/textual_inversion/config.py
class SdTextualInversionConfig (line 10) | class SdTextualInversionConfig(BasePipelineConfig):
method check_validation_prompts (line 198) | def check_validation_prompts(self):
FILE: src/invoke_training/pipelines/stable_diffusion/textual_inversion/train.py
function _save_ti_embeddings (line 41) | def _save_ti_embeddings(
function _initialize_placeholder_tokens (line 80) | def _initialize_placeholder_tokens(
function train (line 138) | def train(config: SdTextualInversionConfig, callbacks: list[PipelineCall...
FILE: src/invoke_training/pipelines/stable_diffusion_xl/finetune/config.py
class SdxlFinetuneConfig (line 10) | class SdxlFinetuneConfig(BasePipelineConfig):
method check_validation_prompts (line 166) | def check_validation_prompts(self):
FILE: src/invoke_training/pipelines/stable_diffusion_xl/finetune/train.py
function _save_sdxl_checkpoint (line 44) | def _save_sdxl_checkpoint(
function train (line 96) | def train(config: SdxlFinetuneConfig, callbacks: list[PipelineCallbacks]...
FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora/config.py
class SdxlLoraConfig (line 14) | class SdxlLoraConfig(BasePipelineConfig):
method check_validation_prompts (line 230) | def check_validation_prompts(self):
FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora/train.py
function _save_sdxl_lora_checkpoint (line 49) | def _save_sdxl_lora_checkpoint(
function _build_data_loader (line 88) | def _build_data_loader(
function _encode_prompt (line 131) | def _encode_prompt(text_encoders: list[CLIPPreTrainedModel], prompt_toke...
function cache_text_encoder_outputs (line 158) | def cache_text_encoder_outputs(
function train_forward (line 200) | def train_forward( # noqa: C901
function train (line 335) | def train(config: SdxlLoraConfig, callbacks: list[PipelineCallbacks] | N...
FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/config.py
class SdxlLoraAndTextualInversionConfig (line 10) | class SdxlLoraAndTextualInversionConfig(BasePipelineConfig):
method check_validation_prompts (line 233) | def check_validation_prompts(self):
FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/train.py
function _save_sdxl_lora_and_ti_checkpoint (line 50) | def _save_sdxl_lora_and_ti_checkpoint(
function train (line 118) | def train(config: SdxlLoraAndTextualInversionConfig, callbacks: list[Pip...
FILE: src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/config.py
class SdxlTextualInversionConfig (line 10) | class SdxlTextualInversionConfig(BasePipelineConfig):
method check_validation_prompts (line 200) | def check_validation_prompts(self):
FILE: src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/train.py
function _save_ti_embeddings (line 44) | def _save_ti_embeddings(
function _initialize_placeholder_tokens (line 93) | def _initialize_placeholder_tokens(
function train (line 164) | def train(config: SdxlTextualInversionConfig, callbacks: list[PipelineCa...
FILE: src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py
function select_device_and_dtype (line 14) | def select_device_and_dtype(force_cpu: bool = False) -> tuple[torch.devi...
function process_images (line 24) | def process_images(images: list[Image.Image], prompt: str, moondream, to...
function main (line 35) | def main(
FILE: src/invoke_training/scripts/_experimental/masks/clipseg.py
function load_clipseg_model (line 6) | def load_clipseg_model() -> tuple[CLIPSegProcessor, CLIPSegForImageSegme...
function run_clipseg (line 13) | def run_clipseg(
function select_device (line 58) | def select_device() -> torch.device:
FILE: src/invoke_training/scripts/_experimental/masks/generate_masks.py
function generate_masks (line 13) | def generate_masks(image_dir: str, prompt: str, clipseg_temp: float, bat...
function main (line 53) | def main():
FILE: src/invoke_training/scripts/_experimental/masks/generate_masks_for_jsonl_dataset.py
function collate_fn (line 16) | def collate_fn(examples):
function validate_out_json_path (line 24) | def validate_out_json_path(out_json_path: str | Path):
function generate_masks (line 33) | def generate_masks(
function main (line 92) | def main():
FILE: src/invoke_training/scripts/_experimental/rank_images.py
function parse_args (line 15) | def parse_args():
function clip (line 28) | def clip(val, min_val, max_val):
function main (line 32) | def main():
FILE: src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py
function parse_args (line 11) | def parse_args():
function main (line 36) | def main():
FILE: src/invoke_training/scripts/invoke_generate_images.py
function parse_args (line 8) | def parse_args():
function parse_lora_args (line 95) | def parse_lora_args(lora_args: list[str] | None) -> list[tuple[Path, int]]:
function parse_prompt_file (line 113) | def parse_prompt_file(prompt_file: str) -> list[str]:
function main (line 120) | def main():
FILE: src/invoke_training/scripts/invoke_train.py
function parse_args (line 11) | def parse_args():
function main (line 23) | def main():
FILE: src/invoke_training/scripts/invoke_train_ui.py
function main (line 8) | def main():
FILE: src/invoke_training/scripts/invoke_visualize_data_loading.py
function save_image (line 26) | def save_image(torch_image: torch.Tensor, out_path: Path):
function parse_args (line 49) | def parse_args():
function visualize (line 62) | def visualize(data_loader: DataLoader):
function main (line 89) | def main():
FILE: src/invoke_training/scripts/utils/image_dir_dataset.py
class ImageDirDataset (line 8) | class ImageDirDataset(torch.utils.data.Dataset):
method __init__ (line 11) | def __init__(
method _load_image (line 29) | def _load_image(self, image_path: str) -> Image.Image:
method __len__ (line 34) | def __len__(self) -> int:
method __getitem__ (line 37) | def __getitem__(self, idx: int):
function list_collate_fn (line 43) | def list_collate_fn(examples):
FILE: src/invoke_training/ui/app.py
function build_app (line 12) | def build_app():
FILE: src/invoke_training/ui/config_groups/aspect_ratio_bucket_config_group.py
class AspectRatioBucketConfigGroup (line 9) | class AspectRatioBucketConfigGroup(UIConfigElement):
method __init__ (line 10) | def __init__(self):
method update_ui_components_with_config_data (line 23) | def update_ui_components_with_config_data(
method update_config_with_ui_component_data (line 41) | def update_config_with_ui_component_data(
FILE: src/invoke_training/ui/config_groups/base_pipeline_config_group.py
class BasePipelineConfigGroup (line 10) | class BasePipelineConfigGroup(UIConfigElement):
method __init__ (line 11) | def __init__(self):
method update_ui_components_with_config_data (line 55) | def update_ui_components_with_config_data(self, config: BasePipelineCo...
method update_config_with_ui_component_data (line 94) | def update_config_with_ui_component_data(
FILE: src/invoke_training/ui/config_groups/dataset_config_group.py
class HFHubImageCaptionDatasetConfigGroup (line 22) | class HFHubImageCaptionDatasetConfigGroup(UIConfigElement):
method __init__ (line 23) | def __init__(self):
method update_ui_components_with_config_data (line 43) | def update_ui_components_with_config_data(
method update_config_with_ui_component_data (line 54) | def update_config_with_ui_component_data(
class ImageCaptionJsonlDatasetConfigGroup (line 70) | class ImageCaptionJsonlDatasetConfigGroup(UIConfigElement):
method __init__ (line 71) | def __init__(self):
method update_ui_components_with_config_data (line 90) | def update_ui_components_with_config_data(
method update_config_with_ui_component_data (line 104) | def update_config_with_ui_component_data(
class ImageCaptionDirDatasetConfigGroup (line 119) | class ImageCaptionDirDatasetConfigGroup(UIConfigElement):
method __init__ (line 120) | def __init__(self):
method update_ui_components_with_config_data (line 133) | def update_ui_components_with_config_data(
method update_config_with_ui_component_data (line 141) | def update_config_with_ui_component_data(
class ImageDirDatasetConfigGroup (line 153) | class ImageDirDatasetConfigGroup(UIConfigElement):
method __init__ (line 154) | def __init__(self):
method update_ui_components_with_config_data (line 167) | def update_ui_components_with_config_data(
method update_config_with_ui_component_data (line 175) | def update_config_with_ui_component_data(
class DatasetConfigGroup (line 187) | class DatasetConfigGroup(UIConfigElement):
method __init__ (line 188) | def __init__(self, allowed_types: list[str]):
method _on_type_change (line 224) | def _on_type_change(self, type: str):
method update_ui_components_with_config_data (line 232) | def update_ui_components_with_config_data(
method update_config_with_ui_component_data (line 270) | def update_config_with_ui_component_data(
FILE: src/invoke_training/ui/config_groups/flux_lora_config_group.py
class FluxLoraConfigGroup (line 15) | class FluxLoraConfigGroup(UIConfigElement):
method __init__ (line 16) | def __init__(self):
method get_ui_output_components (line 185) | def get_ui_output_components(self) -> list[gr.components.Component]:
method update_ui_components_with_config_data (line 224) | def update_ui_components_with_config_data(
method update_config_with_ui_component_data (line 285) | def update_config_with_ui_component_data( # noqa: C901
FILE: src/invoke_training/ui/config_groups/image_caption_sd_data_loader_config_group.py
class ImageCaptionSDDataLoaderConfigGroup (line 11) | class ImageCaptionSDDataLoaderConfigGroup(UIConfigElement):
method __init__ (line 12) | def __init__(self):
method update_ui_components_with_config_data (line 63) | def update_ui_components_with_config_data(
method update_config_with_ui_component_data (line 81) | def update_config_with_ui_component_data(
FILE: src/invoke_training/ui/config_groups/optimizer_config_group.py
class AdamOptimizerConfigGroup (line 11) | class AdamOptimizerConfigGroup(UIConfigElement):
method __init__ (line 12) | def __init__(self):
method update_ui_components_with_config_data (line 34) | def update_ui_components_with_config_data(self, config: AdamOptimizerC...
method update_config_with_ui_component_data (line 44) | def update_config_with_ui_component_data(
class ProdigyOptimizerConfigGroup (line 59) | class ProdigyOptimizerConfigGroup(UIConfigElement):
method __init__ (line 60) | def __init__(self):
method update_ui_components_with_config_data (line 77) | def update_ui_components_with_config_data(
method update_config_with_ui_component_data (line 87) | def update_config_with_ui_component_data(
class OptimizerConfigGroup (line 100) | class OptimizerConfigGroup(UIConfigElement):
method __init__ (line 101) | def __init__(self):
method _on_optimizer_type_change (line 119) | def _on_optimizer_type_change(self, optimizer_type: str):
method update_ui_components_with_config_data (line 125) | def update_ui_components_with_config_data(self, config: OptimizerConfi...
method update_config_with_ui_component_data (line 145) | def update_config_with_ui_component_data(self, orig_config: OptimizerC...
FILE: src/invoke_training/ui/config_groups/sd_lora_config_group.py
class SdLoraConfigGroup (line 19) | class SdLoraConfigGroup(UIConfigElement):
method __init__ (line 20) | def __init__(self):
method update_ui_components_with_config_data (line 180) | def update_ui_components_with_config_data(self, config: SdLoraConfig) ...
method update_config_with_ui_component_data (line 218) | def update_config_with_ui_component_data(
FILE: src/invoke_training/ui/config_groups/sd_textual_inversion_config_group.py
class SdTextualInversionConfigGroup (line 19) | class SdTextualInversionConfigGroup(UIConfigElement):
method __init__ (line 20) | def __init__(self):
method update_ui_components_with_config_data (line 180) | def update_ui_components_with_config_data(
method update_config_with_ui_component_data (line 218) | def update_config_with_ui_component_data(
FILE: src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py
class SdxlFinetuneConfigGroup (line 19) | class SdxlFinetuneConfigGroup(UIConfigElement):
method __init__ (line 20) | def __init__(self):
method update_ui_components_with_config_data (line 177) | def update_ui_components_with_config_data(
method update_config_with_ui_component_data (line 215) | def update_config_with_ui_component_data(
FILE: src/invoke_training/ui/config_groups/sdxl_lora_and_textual_inversion_config_group.py
class SdxlLoraAndTextualInversionConfigGroup (line 21) | class SdxlLoraAndTextualInversionConfigGroup(UIConfigElement):
method __init__ (line 22) | def __init__(self):
method update_ui_components_with_config_data (line 225) | def update_ui_components_with_config_data(
method update_config_with_ui_component_data (line 273) | def update_config_with_ui_component_data(
FILE: src/invoke_training/ui/config_groups/sdxl_lora_config_group.py
class SdxlLoraConfigGroup (line 19) | class SdxlLoraConfigGroup(UIConfigElement):
method __init__ (line 20) | def __init__(self):
method update_ui_components_with_config_data (line 186) | def update_ui_components_with_config_data(
method update_config_with_ui_component_data (line 227) | def update_config_with_ui_component_data(
FILE: src/invoke_training/ui/config_groups/sdxl_textual_inversion_config_group.py
class SdxlTextualInversionConfigGroup (line 19) | class SdxlTextualInversionConfigGroup(UIConfigElement):
method __init__ (line 20) | def __init__(self):
method update_ui_components_with_config_data (line 186) | def update_ui_components_with_config_data(
method update_config_with_ui_component_data (line 225) | def update_config_with_ui_component_data(
FILE: src/invoke_training/ui/config_groups/textual_inversion_sd_data_loader_config_group.py
class TextualInversionSDDataLoaderConfigGroup (line 13) | class TextualInversionSDDataLoaderConfigGroup(UIConfigElement):
method __init__ (line 14) | def __init__(self):
method update_ui_components_with_config_data (line 88) | def update_ui_components_with_config_data(
method update_config_with_ui_component_data (line 114) | def update_config_with_ui_component_data(
FILE: src/invoke_training/ui/config_groups/ui_config_element.py
class UIConfigElement (line 6) | class UIConfigElement:
method get_ui_output_components (line 9) | def get_ui_output_components(self) -> list[gr.components.Component]:
method get_ui_input_components (line 19) | def get_ui_input_components(self) -> list[gr.components.Component]:
method update_ui_components_with_config_data (line 29) | def update_ui_components_with_config_data(self, config) -> dict[gr.com...
method update_config_with_ui_component_data (line 33) | def update_config_with_ui_component_data(self, orig_config, ui_data: d...
FILE: src/invoke_training/ui/gradio_blocks/header.py
class Header (line 6) | class Header:
method __init__ (line 7) | def __init__(self):
FILE: src/invoke_training/ui/gradio_blocks/pipeline_tab.py
class PipelineTab (line 11) | class PipelineTab:
method __init__ (line 12) | def __init__(
method on_reset_config_button_click (line 102) | def on_reset_config_button_click(self, file_path: str):
method on_generate_config_button_click (line 129) | def on_generate_config_button_click(self, data: dict):
method on_run_training_button_click (line 162) | def on_run_training_button_click(self):
FILE: src/invoke_training/ui/pages/data_page.py
class DataPage (line 18) | class DataPage:
method __init__ (line 19) | def __init__(self):
method _update_state (line 190) | def _update_state(self, idx: int):
method _on_load_existing_dataset_button_click (line 234) | def _on_load_existing_dataset_button_click(self, data: dict):
method _on_create_dataset_button_click (line 249) | def _on_create_dataset_button_click(self, data: dict):
method _on_change_dataset_button_click (line 269) | def _on_change_dataset_button_click(self):
method _on_save_and_go_button_click (line 274) | def _on_save_and_go_button_click(self, data: dict, idx_change: int):
method _on_save_and_next_button_click (line 288) | def _on_save_and_next_button_click(self, data: dict):
method _on_save_and_prev_button_click (line 291) | def _on_save_and_prev_button_click(self, data: dict):
method _on_cur_example_index_change (line 294) | def _on_cur_example_index_change(self, data: dict):
method _on_add_images_button_click (line 297) | def _on_add_images_button_click(self, data: dict):
method app (line 336) | def app(self):
FILE: src/invoke_training/ui/pages/training_page.py
class TrainingPage (line 33) | class TrainingPage:
method __init__ (line 34) | def __init__(self):
method app (line 152) | def app(self):
method _run_training (line 155) | def _run_training(self, config: PipelineConfig):
FILE: src/invoke_training/ui/utils/prompts.py
function split_pos_neg_prompts (line 4) | def split_pos_neg_prompts(prompt: str) -> tuple[str, str]:
function merge_pos_neg_prompts (line 28) | def merge_pos_neg_prompts(positive_prompt: str, negative_prompt: str) ->...
function convert_ui_prompts_to_pos_neg_prompts (line 47) | def convert_ui_prompts_to_pos_neg_prompts(prompts: str) -> tuple[list[st...
function convert_pos_neg_prompts_to_ui_prompts (line 69) | def convert_pos_neg_prompts_to_ui_prompts(positive_prompts: list[str], n...
FILE: src/invoke_training/ui/utils/utils.py
function get_config_dir_path (line 10) | def get_config_dir_path() -> Path:
function get_assets_dir_path (line 17) | def get_assets_dir_path() -> Path:
function load_config_from_yaml (line 24) | def load_config_from_yaml(file_path: Path | str) -> PipelineConfig:
function get_typing_literal_options (line 35) | def get_typing_literal_options(cls, field_name: str) -> list[str]:
FILE: tests/invoke_training/_shared/checkpoints/test_checkpoint_tracker.py
function test_checkpoint_tracker_get_path_file (line 10) | def test_checkpoint_tracker_get_path_file():
function test_checkpoint_tracker_get_path_directory (line 24) | def test_checkpoint_tracker_get_path_directory():
function test_checkpoint_tracker_bad_extension (line 38) | def test_checkpoint_tracker_bad_extension():
function test_checkpoint_tracker_prune_files (line 46) | def test_checkpoint_tracker_prune_files():
function test_checkpoint_tracker_prune_directories (line 65) | def test_checkpoint_tracker_prune_directories():
function test_checkpoint_tracker_prune_no_max (line 86) | def test_checkpoint_tracker_prune_no_max():
FILE: tests/invoke_training/_shared/checkpoints/test_serialization.py
function test_state_dict_save_and_load_roundtrip (line 14) | def test_state_dict_save_and_load_roundtrip(file_name):
function test_save_state_dict_bad_extension (line 29) | def test_save_state_dict_bad_extension():
function test_load_state_dict_bad_extension (line 35) | def test_load_state_dict_bad_extension():
FILE: tests/invoke_training/_shared/data/data_loaders/test_dreambooth_sd_dataloader.py
function test_build_dreambooth_sd_dataloader (line 12) | def test_build_dreambooth_sd_dataloader(image_dir): # noqa: F811
function test_build_dreambooth_sd_dataloader_no_class_dataset (line 48) | def test_build_dreambooth_sd_dataloader_no_class_dataset(image_dir): # ...
function test_build_dreambooth_sd_dataloader_with_bucketing (line 82) | def test_build_dreambooth_sd_dataloader_with_bucketing(image_dir): # no...
FILE: tests/invoke_training/_shared/data/data_loaders/test_image_caption_sd_dataloader.py
function test_build_image_caption_sd_dataloader (line 12) | def test_build_image_caption_sd_dataloader(image_caption_jsonl): # noqa...
function test_build_image_caption_sd_dataloader_with_masks (line 41) | def test_build_image_caption_sd_dataloader_with_masks(image_caption_json...
FILE: tests/invoke_training/_shared/data/data_loaders/test_image_pair_preference_sd_dataloader.py
function test_build_image_pair_preference_sd_dataloader (line 16) | def test_build_image_pair_preference_sd_dataloader():
FILE: tests/invoke_training/_shared/data/data_loaders/test_textual_inversion_sd_dataloader.py
function test_build_textual_inversion_sd_dataloader (line 15) | def test_build_textual_inversion_sd_dataloader(image_dir): # noqa: F811
function test_build_textual_inversion_sd_dataloader_keep_original_captions (line 50) | def test_build_textual_inversion_sd_dataloader_keep_original_captions(im...
function test_build_textual_inversion_sd_dataloader_with_masks (line 72) | def test_build_textual_inversion_sd_dataloader_with_masks(image_caption_...
FILE: tests/invoke_training/_shared/data/dataset_fixtures.py
function image_dir (line 10) | def image_dir(tmp_path_factory: pytest.TempPathFactory):
function image_caption_dir (line 30) | def image_caption_dir(tmp_path_factory: pytest.TempPathFactory):
function image_caption_jsonl (line 53) | def image_caption_jsonl(tmp_path_factory: pytest.TempPathFactory):
function image_pair_preference_dir (line 88) | def image_pair_preference_dir(tmp_path_factory: pytest.TempPathFactory):
FILE: tests/invoke_training/_shared/data/datasets/test_hf_image_caption_dataset.py
function create_hf_imagefolder_dataset (line 19) | def create_hf_imagefolder_dataset(tmp_dir: Path, num_images: int):
function hf_imagefolder_dir (line 46) | def hf_imagefolder_dir(tmp_path_factory: pytest.TempPathFactory):
function hf_dir_dataset (line 62) | def hf_dir_dataset(hf_imagefolder_dir: Path):
function test_hf_dir_image_caption_dataset_bad_image_column (line 66) | def test_hf_dir_image_caption_dataset_bad_image_column(hf_imagefolder_di...
function test_hf_dir_image_caption_dataset_bad_caption_column (line 74) | def test_hf_dir_image_caption_dataset_bad_caption_column(hf_imagefolder_...
function test_hf_dir_image_caption_dataset_len (line 82) | def test_hf_dir_image_caption_dataset_len(hf_dir_dataset: HFImageCaption...
function test_hf_dir_image_caption_dataset_index_error (line 87) | def test_hf_dir_image_caption_dataset_index_error(hf_dir_dataset: HFImag...
function test_hf_dir_image_caption_dataset_getitem (line 93) | def test_hf_dir_image_caption_dataset_getitem(hf_dir_dataset: HFImageCap...
function test_hf_dir_image_caption_dataset_get_image_dimensions (line 104) | def test_hf_dir_image_caption_dataset_get_image_dimensions(hf_dir_datase...
function test_hf_hub_image_caption_dataset_bad_image_column (line 121) | def test_hf_hub_image_caption_dataset_bad_image_column():
function test_hf_hub_image_caption_dataset_bad_caption_column (line 135) | def test_hf_hub_image_caption_dataset_bad_caption_column():
function hf_hub_dataset (line 148) | def hf_hub_dataset():
function test_hf_hub_image_caption_dataset_index_error (line 157) | def test_hf_hub_image_caption_dataset_index_error(hf_hub_dataset: HFImag...
function test_hf_hub_image_caption_dataset_len (line 165) | def test_hf_hub_image_caption_dataset_len(hf_hub_dataset: HFImageCaption...
function test_hf_hub_image_caption_dataset_getitem (line 174) | def test_hf_hub_image_caption_dataset_getitem(hf_hub_dataset: HFImageCap...
function test_hf_hub_image_caption_dataset_get_image_dimensions (line 187) | def test_hf_hub_image_caption_dataset_get_image_dimensions(hf_hub_datase...
FILE: tests/invoke_training/_shared/data/datasets/test_hf_image_pair_preference_dataset.py
function test_hf_hub_image_caption_dataset_getitem (line 9) | def test_hf_hub_image_caption_dataset_getitem():
function test_hf_hub_image_caption_dataset_len (line 45) | def test_hf_hub_image_caption_dataset_len():
function test_hf_hub_image_caption_dataset_skip_no_preference_len (line 66) | def test_hf_hub_image_caption_dataset_skip_no_preference_len():
FILE: tests/invoke_training/_shared/data/datasets/test_image_caption_dir_dataset.py
function test_image_caption_dir_dataset_len (line 11) | def test_image_caption_dir_dataset_len(image_caption_dir): # noqa: F811
function test_image_caption_dir_dataset_getitem (line 17) | def test_image_caption_dir_dataset_getitem(image_caption_dir): # noqa: ...
function test_image_caption_dir_dataset_keep_in_memory (line 29) | def test_image_caption_dir_dataset_keep_in_memory(image_caption_dir): #...
function test_image_caption_dir_dataset_get_image_dimensions (line 41) | def test_image_caption_dir_dataset_get_image_dimensions(image_caption_di...
function test_image_caption_dir_dataset_missing_caption_file (line 49) | def test_image_caption_dir_dataset_missing_caption_file(tmp_path: Path):...
FILE: tests/invoke_training/_shared/data/datasets/test_image_caption_jsonl_dataset.py
function test_image_caption_jsonl_dataset_len (line 12) | def test_image_caption_jsonl_dataset_len(image_caption_jsonl): # noqa: ...
function test_image_caption_jsonl_dataset_getitem (line 18) | def test_image_caption_jsonl_dataset_getitem(image_caption_jsonl): # no...
function test_image_caption_jsonl_dataset_keep_in_memory (line 32) | def test_image_caption_jsonl_dataset_keep_in_memory(image_caption_jsonl)...
function test_image_caption_jsonl_dataset_get_image_dimensions (line 53) | def test_image_caption_jsonl_dataset_get_image_dimensions(image_caption_...
function test_image_caption_jsonl_dataset_save_jsonl (line 61) | def test_image_caption_jsonl_dataset_save_jsonl(image_caption_jsonl, tmp...
FILE: tests/invoke_training/_shared/data/datasets/test_image_dir_dataset.py
function test_image_dir_dataset_len (line 8) | def test_image_dir_dataset_len(image_dir): # noqa: F811
function test_image_dir_dataset_getitem (line 14) | def test_image_dir_dataset_getitem(image_dir): # noqa: F811
function test_image_dir_dataset_keep_in_memory (line 25) | def test_image_dir_dataset_keep_in_memory(image_dir): # noqa: F811
function test_image_dir_dataset_get_image_dimensions (line 43) | def test_image_dir_dataset_get_image_dimensions(image_dir): # noqa: F811
FILE: tests/invoke_training/_shared/data/datasets/test_image_pair_preference_dataset.py
function test_image_dir_dataset_len (line 8) | def test_image_dir_dataset_len(image_pair_preference_dir): # noqa: F811
function test_image_dir_dataset_getitem (line 14) | def test_image_dir_dataset_getitem(image_pair_preference_dir): # noqa: ...
FILE: tests/invoke_training/_shared/data/datasets/test_transform_dataset.py
function test_transform_dataset_len (line 6) | def test_transform_dataset_len():
function test_transform_dataset_getitem (line 16) | def test_transform_dataset_getitem():
FILE: tests/invoke_training/_shared/data/samplers/test_aspect_ratio_bucket_batch_sampler.py
function assert_shuffled_samples_match (line 8) | def assert_shuffled_samples_match(samples_1, samples_2):
function test_aspect_ratio_bucket_batch_sampler (line 18) | def test_aspect_ratio_bucket_batch_sampler():
function test_aspect_ratio_bucket_batch_sampler_len (line 30) | def test_aspect_ratio_bucket_batch_sampler_len():
function test_aspect_ratio_bucket_batch_sampler_from_image_sizes (line 42) | def test_aspect_ratio_bucket_batch_sampler_from_image_sizes():
function test_aspect_ratio_bucket_batch_sampler_shuffle (line 66) | def test_aspect_ratio_bucket_batch_sampler_shuffle():
function test_aspect_ratio_bucket_batch_sampler_seed (line 81) | def test_aspect_ratio_bucket_batch_sampler_seed():
FILE: tests/invoke_training/_shared/data/samplers/test_batch_offset_sampler.py
function test_batch_offset_sampler (line 6) | def test_batch_offset_sampler():
function test_batch_offset_sampler_len (line 18) | def test_batch_offset_sampler_len():
FILE: tests/invoke_training/_shared/data/samplers/test_concat_sampler.py
function test_concat_sampler (line 4) | def test_concat_sampler():
function test_concat_sampler_batches (line 16) | def test_concat_sampler_batches():
function test_concat_sampler_len (line 28) | def test_concat_sampler_len():
FILE: tests/invoke_training/_shared/data/samplers/test_interleaved_sampler.py
function test_interleaved_sampler (line 4) | def test_interleaved_sampler():
function test_interleaved_sampler_batches (line 16) | def test_interleaved_sampler_batches():
function test_interleaved_sampler_len (line 28) | def test_interleaved_sampler_len():
FILE: tests/invoke_training/_shared/data/samplers/test_offset_sampler.py
function test_offset_sampler (line 6) | def test_offset_sampler():
function test_offset_sampler_len (line 16) | def test_offset_sampler_len():
FILE: tests/invoke_training/_shared/data/transforms/test_caption_prefix_transform.py
function test_caption_prefix_transform (line 4) | def test_caption_prefix_transform():
FILE: tests/invoke_training/_shared/data/transforms/test_concat_fields_transform.py
function test_caption_prefix_transform (line 4) | def test_caption_prefix_transform():
FILE: tests/invoke_training/_shared/data/transforms/test_constant_field_transform.py
function test_constant_field_transform (line 4) | def test_constant_field_transform():
FILE: tests/invoke_training/_shared/data/transforms/test_drop_field_transform.py
function test_drop_field_transform (line 4) | def test_drop_field_transform():
FILE: tests/invoke_training/_shared/data/transforms/test_load_cache_transform.py
function test_load_cache_transform (line 8) | def test_load_cache_transform():
FILE: tests/invoke_training/_shared/data/transforms/test_sd_image_transform.py
function denormalize_image (line 13) | def denormalize_image(img: np.ndarray) -> np.ndarray:
function denormalize_mask (line 32) | def denormalize_mask(mask: np.ndarray) -> np.ndarray:
function test_sd_image_transform_resolution (line 41) | def test_sd_image_transform_resolution():
function test_sd_image_transform_without_mask (line 69) | def test_sd_image_transform_without_mask():
function test_sd_image_transform_range (line 92) | def test_sd_image_transform_range():
function test_sd_image_transform_center_crop (line 124) | def test_sd_image_transform_center_crop():
function test_sd_image_transform_random_crop (line 155) | def test_sd_image_transform_random_crop():
function test_sd_image_transform_center_crop_flip (line 193) | def test_sd_image_transform_center_crop_flip():
function test_sd_image_transform_random_crop_flip (line 228) | def test_sd_image_transform_random_crop_flip():
function test_sd_image_transform_aspect_ratio_bucket_manager (line 270) | def test_sd_image_transform_aspect_ratio_bucket_manager():
function test_sd_image_transform_resolution_input_validation (line 309) | def test_sd_image_transform_resolution_input_validation(
FILE: tests/invoke_training/_shared/data/transforms/test_shuffle_caption_transform.py
function test_shuffle_caption_transform (line 4) | def test_shuffle_caption_transform():
function test_shuffle_caption_transform_no_delimiter (line 15) | def test_shuffle_caption_transform_no_delimiter():
FILE: tests/invoke_training/_shared/data/transforms/test_template_caption_transform.py
function test_template_caption_transform (line 8) | def test_template_caption_transform():
function test_template_caption_transform_seed (line 20) | def test_template_caption_transform_seed():
function test_template_caption_transform_bad_templates (line 55) | def test_template_caption_transform_bad_templates():
FILE: tests/invoke_training/_shared/data/transforms/test_tensor_disk_cache.py
function test_tensor_disk_cache_roundtrip (line 9) | def test_tensor_disk_cache_roundtrip(tmp_path: Path):
function test_tensor_disk_cache_fail_overwrite (line 26) | def test_tensor_disk_cache_fail_overwrite(tmp_path):
FILE: tests/invoke_training/_shared/data/utils/test_aspect_ratio_bucket_manager.py
function test_build_aspect_ratio_buckets_input_validation (line 19) | def test_build_aspect_ratio_buckets_input_validation(
function test_build_aspect_ratio_buckets (line 54) | def test_build_aspect_ratio_buckets(
function test_get_aspect_ratio_bucket (line 79) | def test_get_aspect_ratio_bucket(resolution: Resolution, expected_bucket...
FILE: tests/invoke_training/_shared/data/utils/test_resize.py
function test_resize_to_cover (line 28) | def test_resize_to_cover(in_resolution: Resolution, size_to_cover: Resol...
FILE: tests/invoke_training/_shared/data/utils/test_resolution.py
function test_resolution_parse (line 14) | def test_resolution_parse(input, expected_resolution: Resolution):
FILE: tests/invoke_training/_shared/stable_diffusion/test_base_model_version.py
function test_get_base_model_version (line 21) | def test_get_base_model_version(diffusers_model_name: str, expected_vers...
function test_check_base_model_version_pass (line 39) | def test_check_base_model_version_pass():
function test_check_base_model_version_fail (line 45) | def test_check_base_model_version_fail():
FILE: tests/invoke_training/_shared/stable_diffusion/test_lora_checkpoint_utils.py
function test_convert_sd_peft_checkpoint_to_kohya_state_dict_raise_on_empty_directory (line 10) | def test_convert_sd_peft_checkpoint_to_kohya_state_dict_raise_on_empty_d...
function test_convert_sd_peft_checkpoint_to_kohya_state_dict_raise_on_unexpected_subdirectory (line 17) | def test_convert_sd_peft_checkpoint_to_kohya_state_dict_raise_on_unexpec...
FILE: tests/invoke_training/_shared/stable_diffusion/test_model_loading_utils.py
function test_load_models_sd (line 17) | def test_load_models_sd(sdv1_embedding_path): # noqa: F811
function test_load_models_sdxl (line 38) | def test_load_models_sdxl(sdxl_embedding_path: Path): # noqa: F811
FILE: tests/invoke_training/_shared/stable_diffusion/test_textual_inversion.py
function test_expand_placeholder_token (line 22) | def test_expand_placeholder_token(placeholder_token: str, num_vectors: i...
function test_expand_placeholder_token_raises_on_invalid_num_vectors (line 26) | def test_expand_placeholder_token_raises_on_invalid_num_vectors():
function test_initialize_placeholder_tokens_from_initializer_token (line 32) | def test_initialize_placeholder_tokens_from_initializer_token():
function test_initialize_placeholder_tokens_from_initial_phrase (line 60) | def test_initialize_placeholder_tokens_from_initial_phrase():
function test_initialize_placeholder_tokens_from_initial_embedding (line 87) | def test_initialize_placeholder_tokens_from_initial_embedding(sdv1_embed...
FILE: tests/invoke_training/_shared/stable_diffusion/ti_embedding_checkpoint_fixture.py
function sdv1_embedding_path (line 8) | def sdv1_embedding_path(tmp_path_factory: pytest.TempPathFactory):
function sdxl_embedding_path (line 26) | def sdxl_embedding_path(tmp_path_factory: pytest.TempPathFactory):
FILE: tests/invoke_training/_shared/utils/test_jsonl.py
function test_jsonl_roundtrip (line 6) | def test_jsonl_roundtrip(tmp_path: Path):
FILE: tests/invoke_training/config/pipelines/test_pipeline_config.py
function test_pipeline_config (line 10) | def test_pipeline_config():
FILE: tests/invoke_training/model_merge/test_merge_models.py
function test_merge_models_raises_on_not_enough_state_dicts (line 12) | def test_merge_models_raises_on_not_enough_state_dicts():
function test_merge_models_raises_on_mismatched_weights (line 17) | def test_merge_models_raises_on_mismatched_weights():
function test_merge_models (line 88) | def test_merge_models(
FILE: tests/invoke_training/model_merge/test_merge_tasks_to_base.py
function test_merge_raises_on_mismatched_weights (line 11) | def test_merge_raises_on_mismatched_weights():
function test_merge_ties (line 65) | def test_merge_ties(
FILE: tests/invoke_training/model_merge/utils.py
function state_dicts_are_close (line 4) | def state_dicts_are_close(a: dict[str, torch.Tensor], b: dict[str, torch...
FILE: tests/invoke_training/ui/utils/test_prompts.py
function test_split_pos_neg_prompts (line 21) | def test_split_pos_neg_prompts(prompt: str, expected_positive_prompt: st...
function test_split_pos_neg_prompts_raises_value_error (line 34) | def test_split_pos_neg_prompts_raises_value_error(prompt: str):
function test_convert_ui_prompts_to_pos_neg_prompts (line 73) | def test_convert_ui_prompts_to_pos_neg_prompts(
function test_convert_pos_neg_prompts_to_ui_prompts (line 82) | def test_convert_pos_neg_prompts_to_ui_prompts(
Condensed preview — 246 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (1,009K chars).
[
{
"path": ".github/workflows/deploy.yaml",
"chars": 954,
"preview": "name: Deploy invoke-training docs\n\non:\n push:\n branches:\n - main\n\npermissions:\n contents: write\n\njobs:\n deplo"
},
{
"path": ".github/workflows/test.yaml",
"chars": 1270,
"preview": "name: Test invoke-training\n\non:\n push:\n branches:\n - main\n pull_request:\n workflow_dispatch:\n\njobs:\n build:\n"
},
{
"path": ".gitignore",
"chars": 3169,
"preview": "/output/\n/test_configs/\n/data/\n\n# pyenv\n.python-version\n\n# VSCode\n.vscode/\n\n# Byte-compiled / optimized / DLL files\n__py"
},
{
"path": ".pre-commit-config.yaml",
"chars": 237,
"preview": "# See https://pre-commit.com/ for usage and config.\nrepos:\n- repo: https://github.com/astral-sh/ruff-pre-commit\n # Ruff"
},
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 2179,
"preview": "# invoke-training\n\nA library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion,"
},
{
"path": "docs/contributing/development_environment.md",
"chars": 135,
"preview": "# Development Environment Setup\n\nSee the [developer installation instructions](../get-started/installation.md#developer-"
},
{
"path": "docs/contributing/directory_structure.md",
"chars": 706,
"preview": "# Directory Structure\n\n```bash\ninvoke-training/\n├── README.md\n├── docs/\n├── src/\n│ └── invoke-training/\n│ ├── _s"
},
{
"path": "docs/contributing/documentation.md",
"chars": 225,
"preview": "# Documentation\n\nThe documentation site is generated using [mkdocs](https://www.mkdocs.org/) and [mkdocstrings-python](h"
},
{
"path": "docs/contributing/tests.md",
"chars": 393,
"preview": "# Tests\n\nRun all unit tests with:\n\n```bash\npytest tests/\n```\n\nThere are some test 'markers' defined in [pyproject.toml]("
},
{
"path": "docs/get-started/installation.md",
"chars": 2348,
"preview": "# Installation\n\n## Requirements\n\n1. Python 3.10, 3.11 and 3.12 are currently supported. Check your Python version by run"
},
{
"path": "docs/get-started/quick-start.md",
"chars": 3352,
"preview": "# Quick Start\n\n`invoke-training` has both a GUI and a CLI (for advanced users). The instructions for getting started wit"
},
{
"path": "docs/guides/dataset_formats.md",
"chars": 3778,
"preview": "# Dataset Formats\n\n`invoke-training` supports the following dataset formats:\n\n- `IMAGE_CAPTION_JSONL_DATASET`: A local i"
},
{
"path": "docs/guides/model_merge.md",
"chars": 1824,
"preview": "# Model Merging\n\n`invoke-training` provides utility scripts for several common model merging workflows. This page contai"
},
{
"path": "docs/guides/stable_diffusion/dpo_lora_sd.md",
"chars": 3900,
"preview": "# (Experimental) Diffusion DPO - SD\n\n!!! tip \"Experimental\"\n The Diffusion Direct Preference Optimization training pi"
},
{
"path": "docs/guides/stable_diffusion/gnome_lora_masks_sdxl.md",
"chars": 4972,
"preview": "# LoRA with Masks - SDXL\n\nThis tutorial explains how to prepare masks for an image dataset and then use that dataset to "
},
{
"path": "docs/guides/stable_diffusion/robocats_finetune_sdxl.md",
"chars": 7415,
"preview": "# Finetune - SDXL\n\nThis tutorial explains how to do a full finetune training run on a [Stable Diffusion XL](https://hugg"
},
{
"path": "docs/guides/stable_diffusion/textual_inversion_sdxl.md",
"chars": 3804,
"preview": "# Textual Inversion - SDXL\n\nThis tutorial walks through a [Textual Inversion](https://arxiv.org/abs/2208.01618) training"
},
{
"path": "docs/index.md",
"chars": 693,
"preview": "# invoke-training\n\nA library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion,"
},
{
"path": "docs/reference/config/index.md",
"chars": 373,
"preview": "# Config Reference\n\nThis section contains reference documentation for the `invoke-training` configuration schema (i.e. d"
},
{
"path": "docs/reference/config/pipelines/sd_lora.md",
"chars": 478,
"preview": "# `SdLoraConfig`\n\n<!-- To control the member order, we first list out the members whose order we care about, then we lis"
},
{
"path": "docs/reference/config/pipelines/sd_textual_inversion.md",
"chars": 541,
"preview": "# `SdTextualInversionConfig`\n\n<!-- To control the member order, we first list out the members whose order we care about,"
},
{
"path": "docs/reference/config/pipelines/sdxl_finetune.md",
"chars": 511,
"preview": "# `SdxlFinetuneConfig`\n\n<!-- To control the member order, we first list out the members whose order we care about, then "
},
{
"path": "docs/reference/config/pipelines/sdxl_lora.md",
"chars": 490,
"preview": "# `SdxlLoraConfig`\n\n<!-- To control the member order, we first list out the members whose order we care about, then we l"
},
{
"path": "docs/reference/config/pipelines/sdxl_lora_and_textual_inversion.md",
"chars": 591,
"preview": "# `SdxlLoraAndTextualInversionConfig`\n\n<!-- To control the member order, we first list out the members whose order we ca"
},
{
"path": "docs/reference/config/pipelines/sdxl_textual_inversion.md",
"chars": 988,
"preview": "# `SdxlTextualInversionConfig`\n\nBelow is a sample yaml config file for Textual Inversion SDXL training ([raw file](https"
},
{
"path": "docs/reference/config/shared/data/data_loader_config.md",
"chars": 104,
"preview": "::: invoke_training.config.data.data_loader_config\n options:\n filters:\n - \"!^model_config\"\n"
},
{
"path": "docs/reference/config/shared/data/dataset_config.md",
"chars": 99,
"preview": "::: invoke_training.config.data.dataset_config\n options:\n filters:\n - \"!^model_config\""
},
{
"path": "docs/reference/config/shared/optimizer_config.md",
"chars": 107,
"preview": "::: invoke_training.config.optimizer.optimizer_config\n options:\n filters:\n - \"!^model_config\"\n"
},
{
"path": "docs/templates/python/material/labels.html",
"chars": 266,
"preview": "<!--\n This file is intentionally empty. It overrides the default contents of\n https://github.com/mkdocstrings/pyth"
},
{
"path": "mkdocs.yml",
"chars": 2583,
"preview": "site_name: invoke-training\nsite_url: https://invoke-ai.github.io/invoke-training/\n\nrepo_name: invoke-ai/invoke-training\n"
},
{
"path": "pyproject.toml",
"chars": 2104,
"preview": "[build-system]\nrequires = [\"setuptools>=65.5\", \"pip>=22.3\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"i"
},
{
"path": "sample_data/bruce_the_gnome/data.jsonl",
"chars": 451,
"preview": "{\"image\": \"001.png\", \"text\": \"A stuffed gnome sits on a wooden floor, facing right with a gray couch in the background.\""
},
{
"path": "src/invoke_training/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/_shared/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/_shared/accelerator/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/_shared/accelerator/accelerator_utils.py",
"chars": 3573,
"preview": "import logging\nimport os\nfrom typing import Literal\n\nimport datasets\nimport diffusers\nimport torch\nimport transformers\nf"
},
{
"path": "src/invoke_training/_shared/checkpoints/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/_shared/checkpoints/checkpoint_tracker.py",
"chars": 3876,
"preview": "import os\nimport shutil\nimport typing\n\n\nclass CheckpointTracker:\n \"\"\"A utility class for managing checkpoint paths.\n\n"
},
{
"path": "src/invoke_training/_shared/checkpoints/lora_checkpoint_utils.py",
"chars": 3805,
"preview": "from pathlib import Path\n\nimport peft\nimport torch\n\n\ndef save_multi_model_peft_checkpoint(checkpoint_dir: Path | str, mo"
},
{
"path": "src/invoke_training/_shared/checkpoints/serialization.py",
"chars": 1883,
"preview": "import typing\nfrom pathlib import Path\n\nimport safetensors.torch\nimport torch\n\n\ndef save_state_dict(state_dict: typing.D"
},
{
"path": "src/invoke_training/_shared/data/ARCHITECTURE.md",
"chars": 1284,
"preview": "# Dataset Architecture\nDataset handling is split into 3 layers of abstraction: Datasets, Transforms, and DataLoaders. Ea"
},
{
"path": "src/invoke_training/_shared/data/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/_shared/data/data_loaders/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/_shared/data/data_loaders/dreambooth_sd_dataloader.py",
"chars": 8355,
"preview": "import typing\n\nfrom torch.utils.data import ConcatDataset, DataLoader\nfrom torch.utils.data.sampler import RandomSampler"
},
{
"path": "src/invoke_training/_shared/data/data_loaders/image_caption_flux_dataloader.py",
"chars": 6423,
"preview": "import typing\n\nfrom torch.utils.data import DataLoader\n\nfrom invoke_training._shared.data.data_loaders.image_caption_sd_"
},
{
"path": "src/invoke_training/_shared/data/data_loaders/image_caption_sd_dataloader.py",
"chars": 8199,
"preview": "import typing\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom invoke_training._shared.data.datasets.build_da"
},
{
"path": "src/invoke_training/_shared/data/data_loaders/image_pair_preference_sd_dataloader.py",
"chars": 5973,
"preview": "import typing\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom invoke_training._shared.data.datasets.build_da"
},
{
"path": "src/invoke_training/_shared/data/data_loaders/textual_inversion_sd_dataloader.py",
"chars": 10371,
"preview": "from typing import Literal, Optional\n\nfrom torch.utils.data import DataLoader\n\nfrom invoke_training._shared.data.data_lo"
},
{
"path": "src/invoke_training/_shared/data/datasets/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/_shared/data/datasets/build_dataset.py",
"chars": 2930,
"preview": "from datasets import VerificationMode\n\nfrom invoke_training._shared.data.datasets.hf_image_caption_dataset import HFImag"
},
{
"path": "src/invoke_training/_shared/data/datasets/hf_image_caption_dataset.py",
"chars": 5046,
"preview": "import os\nimport typing\n\nimport datasets\nimport torch.utils.data\nfrom PIL.Image import Image\n\nfrom invoke_training._shar"
},
{
"path": "src/invoke_training/_shared/data/datasets/hf_image_pair_preference_dataset.py",
"chars": 4649,
"preview": "import io\nimport typing\n\nimport datasets\nimport torch.utils.data\nfrom PIL import Image\n\n\nclass HFImagePairPreferenceData"
},
{
"path": "src/invoke_training/_shared/data/datasets/image_caption_dir_dataset.py",
"chars": 3753,
"preview": "import os\nimport typing\n\nimport torch.utils.data\nfrom PIL import Image\n\nfrom invoke_training._shared.data.utils.resoluti"
},
{
"path": "src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py",
"chars": 4708,
"preview": "import typing\nfrom pathlib import Path\n\nimport torch.utils.data\nfrom PIL import Image\nfrom pydantic import BaseModel\n\nfr"
},
{
"path": "src/invoke_training/_shared/data/datasets/image_dir_dataset.py",
"chars": 2883,
"preview": "import os\nimport typing\n\nimport torch.utils.data\nfrom PIL import Image\n\nfrom invoke_training._shared.data.utils.resoluti"
},
{
"path": "src/invoke_training/_shared/data/datasets/image_pair_preference_dataset.py",
"chars": 1549,
"preview": "import os\nimport typing\nfrom pathlib import Path\n\nimport torch.utils.data\nfrom PIL import Image\n\nfrom invoke_training._s"
},
{
"path": "src/invoke_training/_shared/data/datasets/transform_dataset.py",
"chars": 834,
"preview": "import typing\n\nimport torch.utils.data\n\n# The data type expected to be produced by the base dataset and handled by trans"
},
{
"path": "src/invoke_training/_shared/data/samplers/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/_shared/data/samplers/aspect_ratio_bucket_batch_sampler.py",
"chars": 4341,
"preview": "import copy\nimport logging\nimport math\nimport random\nfrom typing import Iterator\n\nfrom torch.utils.data import Sampler\n\n"
},
{
"path": "src/invoke_training/_shared/data/samplers/batch_offset_sampler.py",
"chars": 560,
"preview": "import typing\n\nfrom torch.utils.data import Sampler\n\n\nclass BatchOffsetSampler(Sampler[int]):\n \"\"\"A sampler that wrap"
},
{
"path": "src/invoke_training/_shared/data/samplers/concat_sampler.py",
"chars": 685,
"preview": "import itertools\nimport typing\n\nfrom torch.utils.data import Sampler\n\nT_co = typing.TypeVar(\"T_co\", covariant=True)\n\n\ncl"
},
{
"path": "src/invoke_training/_shared/data/samplers/interleaved_sampler.py",
"chars": 1264,
"preview": "import typing\n\nfrom torch.utils.data import Sampler\n\nT_co = typing.TypeVar(\"T_co\", covariant=True)\n\n\nclass InterleavedSa"
},
{
"path": "src/invoke_training/_shared/data/samplers/offset_sampler.py",
"chars": 490,
"preview": "import typing\n\nfrom torch.utils.data import Sampler\n\n\nclass OffsetSampler(Sampler[int]):\n \"\"\"A sampler that wraps ano"
},
{
"path": "src/invoke_training/_shared/data/transforms/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/_shared/data/transforms/caption_prefix_transform.py",
"chars": 459,
"preview": "import typing\n\n\nclass CaptionPrefixTransform:\n \"\"\"A transform that adds a prefix to all example captions.\"\"\"\n\n def"
},
{
"path": "src/invoke_training/_shared/data/transforms/concat_fields_transform.py",
"chars": 589,
"preview": "import typing\n\n\nclass ConcatFieldsTransform:\n \"\"\"A transform that concatenate multiple string fields.\"\"\"\n\n def __i"
},
{
"path": "src/invoke_training/_shared/data/transforms/constant_field_transform.py",
"chars": 429,
"preview": "import typing\n\n\nclass ConstantFieldTransform:\n \"\"\"A simple transform that adds a constant field to every example.\"\"\"\n"
},
{
"path": "src/invoke_training/_shared/data/transforms/drop_field_transform.py",
"chars": 391,
"preview": "import typing\n\n\nclass DropFieldTransform:\n \"\"\"A simple transform that drops a field from an example.\"\"\"\n\n def __in"
},
{
"path": "src/invoke_training/_shared/data/transforms/flux_image_transform.py",
"chars": 4527,
"preview": "import typing\n\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\n\nfrom invoke_traini"
},
{
"path": "src/invoke_training/_shared/data/transforms/load_cache_transform.py",
"chars": 1187,
"preview": "import typing\n\nfrom invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache\n\n\nclass LoadCacheTr"
},
{
"path": "src/invoke_training/_shared/data/transforms/sd_image_transform.py",
"chars": 6214,
"preview": "import random\nimport typing\n\nfrom torchvision import transforms\nfrom torchvision.transforms.functional import crop\n\nfrom"
},
{
"path": "src/invoke_training/_shared/data/transforms/shuffle_caption_transform.py",
"chars": 955,
"preview": "import typing\n\nimport numpy as np\n\n\nclass ShuffleCaptionTransform:\n \"\"\"A transform that applies shuffle transformatio"
},
{
"path": "src/invoke_training/_shared/data/transforms/template_caption_transform.py",
"chars": 908,
"preview": "import typing\n\nimport numpy as np\n\n\nclass TemplateCaptionTransform:\n \"\"\"A simple transform that constructs a caption "
},
{
"path": "src/invoke_training/_shared/data/transforms/tensor_disk_cache.py",
"chars": 1525,
"preview": "import os\nimport typing\n\nimport torch\n\n\nclass TensorDiskCache:\n \"\"\"A data cache that caches `torch.Tensor`s on disk.\""
},
{
"path": "src/invoke_training/_shared/data/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/_shared/data/utils/aspect_ratio_bucket_manager.py",
"chars": 2228,
"preview": "from invoke_training._shared.data.utils.resolution import Resolution\n\n\nclass AspectRatioBucketManager:\n def __init__("
},
{
"path": "src/invoke_training/_shared/data/utils/resize.py",
"chars": 1148,
"preview": "import math\n\nfrom PIL.Image import Image\nfrom torchvision import transforms\n\nfrom invoke_training._shared.data.utils.res"
},
{
"path": "src/invoke_training/_shared/data/utils/resolution.py",
"chars": 1176,
"preview": "from typing import Union\n\n\nclass Resolution:\n def __init__(self, height: int, width: int):\n self.height = heig"
},
{
"path": "src/invoke_training/_shared/flux/encoding_utils.py",
"chars": 7754,
"preview": "import logging\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom transformers import CLIPTextModel, CLI"
},
{
"path": "src/invoke_training/_shared/flux/lora_checkpoint_utils.py",
"chars": 19688,
"preview": "# ruff: noqa: N806\nimport os\nfrom pathlib import Path\n\nimport peft\nimport torch\nfrom diffusers import FluxTransformer2DM"
},
{
"path": "src/invoke_training/_shared/flux/model_loading_utils.py",
"chars": 6226,
"preview": "import logging\nfrom enum import Enum\n\nimport torch\nfrom diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler,"
},
{
"path": "src/invoke_training/_shared/flux/validation.py",
"chars": 5099,
"preview": "import logging\nimport os\n\nimport numpy as np\nimport torch\nimport torch.utils.data\nfrom accelerate import Accelerator\nfro"
},
{
"path": "src/invoke_training/_shared/optimizer/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/_shared/optimizer/optimizer_utils.py",
"chars": 1513,
"preview": "import torch\nfrom prodigyopt import Prodigy\n\nfrom invoke_training.config.optimizer.optimizer_config import AdamOptimizer"
},
{
"path": "src/invoke_training/_shared/stable_diffusion/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/_shared/stable_diffusion/base_model_version.py",
"chars": 2986,
"preview": "from enum import Enum\n\nfrom transformers import PretrainedConfig\n\n\nclass BaseModelVersionEnum(Enum):\n STABLE_DIFFUSIO"
},
{
"path": "src/invoke_training/_shared/stable_diffusion/checkpoint_utils.py",
"chars": 2264,
"preview": "from pathlib import Path\n\nimport torch\nfrom diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UN"
},
{
"path": "src/invoke_training/_shared/stable_diffusion/lora_checkpoint_utils.py",
"chars": 7415,
"preview": "import os\nfrom pathlib import Path\n\nimport peft\nimport torch\nfrom diffusers import UNet2DConditionModel\nfrom transformer"
},
{
"path": "src/invoke_training/_shared/stable_diffusion/min_snr_weighting.py",
"chars": 1534,
"preview": "import torch\nfrom diffusers import DDPMScheduler\n\n\ndef compute_snr(noise_scheduler: DDPMScheduler, timesteps: torch.Tens"
},
{
"path": "src/invoke_training/_shared/stable_diffusion/model_loading_utils.py",
"chars": 8016,
"preview": "import logging\nimport os\nimport typing\nfrom enum import Enum\n\nimport torch\nfrom diffusers import (\n AutoencoderKL,\n "
},
{
"path": "src/invoke_training/_shared/stable_diffusion/textual_inversion.py",
"chars": 9269,
"preview": "import logging\n\nimport torch\nfrom accelerate import Accelerator\nfrom transformers import CLIPTextModel, CLIPTokenizer, P"
},
{
"path": "src/invoke_training/_shared/stable_diffusion/tokenize_captions.py",
"chars": 907,
"preview": "import torch\nfrom transformers import CLIPTokenizer\n\nfrom invoke_training._shared.stable_diffusion.textual_inversion imp"
},
{
"path": "src/invoke_training/_shared/stable_diffusion/validation.py",
"chars": 10055,
"preview": "import logging\nimport os\n\nimport numpy as np\nimport torch\nimport torch.utils.data\nfrom accelerate import Accelerator\nfro"
},
{
"path": "src/invoke_training/_shared/tools/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/_shared/tools/generate_images.py",
"chars": 4326,
"preview": "import os\nfrom pathlib import Path\nfrom typing import Optional\n\nimport torch\nfrom tqdm import tqdm\n\nfrom invoke_training"
},
{
"path": "src/invoke_training/_shared/utils/import_xformers.py",
"chars": 291,
"preview": "def import_xformers():\n try:\n import xformers # noqa: F401\n except ImportError:\n raise ImportError("
},
{
"path": "src/invoke_training/_shared/utils/jsonl.py",
"chars": 525,
"preview": "import json\nfrom pathlib import Path\nfrom typing import Any\n\n\ndef load_jsonl(jsonl_path: Path | str) -> list[Any]:\n \""
},
{
"path": "src/invoke_training/config/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/config/base_pipeline_config.py",
"chars": 2184,
"preview": "import typing\nfrom typing import Optional\n\nfrom invoke_training.config.config_base_model import ConfigBaseModel\n\n\nclass "
},
{
"path": "src/invoke_training/config/config_base_model.py",
"chars": 249,
"preview": "from pydantic import BaseModel, ConfigDict\n\n\nclass ConfigBaseModel(BaseModel):\n \"\"\"Base model for all invoke training"
},
{
"path": "src/invoke_training/config/data/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/config/data/data_loader_config.py",
"chars": 7564,
"preview": "from typing import Literal, Optional\n\nfrom invoke_training.config.config_base_model import ConfigBaseModel\nfrom invoke_t"
},
{
"path": "src/invoke_training/config/data/dataset_config.py",
"chars": 2917,
"preview": "from typing import Annotated, Literal, Optional, Union\n\nfrom pydantic import Field\n\nfrom invoke_training.config.config_b"
},
{
"path": "src/invoke_training/config/optimizer/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/config/optimizer/optimizer_config.py",
"chars": 1500,
"preview": "import typing\n\nfrom invoke_training.config.config_base_model import ConfigBaseModel\n\n\nclass AdamOptimizerConfig(ConfigBa"
},
{
"path": "src/invoke_training/config/pipeline_config.py",
"chars": 1198,
"preview": "from typing import Annotated, Union\n\nfrom pydantic import Field\n\nfrom invoke_training.pipelines._experimental.sd_dpo_lor"
},
{
"path": "src/invoke_training/model_merge/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/model_merge/extract_lora.py",
"chars": 3642,
"preview": "import torch\nimport tqdm\nfrom peft.peft_model import PeftModel\n\n# All original base model weights in a PeftModel have th"
},
{
"path": "src/invoke_training/model_merge/merge_models.py",
"chars": 3578,
"preview": "from typing import Literal\n\nimport torch\nimport tqdm\n\nfrom invoke_training.model_merge.utils.normalize_weights import no"
},
{
"path": "src/invoke_training/model_merge/merge_tasks_to_base.py",
"chars": 2929,
"preview": "from typing import Literal\n\nimport torch\nimport tqdm\nfrom peft.utils.merge_utils import dare_linear, dare_ties, ties\n\n\n@"
},
{
"path": "src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py",
"chars": 13509,
"preview": "# This script is based on\n# https://raw.githubusercontent.com/kohya-ss/sd-scripts/bfb352bc433326a77aca3124248331eb60c49e"
},
{
"path": "src/invoke_training/model_merge/scripts/merge_lora_into_model.py",
"chars": 7109,
"preview": "import argparse # noqa: I001\nimport logging\nfrom pathlib import Path\n\nimport torch\nfrom diffusers import StableDiffusio"
},
{
"path": "src/invoke_training/model_merge/scripts/merge_models.py",
"chars": 5432,
"preview": "import argparse\nimport logging\nfrom dataclasses import dataclass\nfrom pathlib import Path\n\nimport torch\nfrom diffusers i"
},
{
"path": "src/invoke_training/model_merge/scripts/merge_task_models_to_base_model.py",
"chars": 6717,
"preview": "import argparse\nimport logging\nfrom pathlib import Path\n\nimport torch\nfrom diffusers import StableDiffusionPipeline, Sta"
},
{
"path": "src/invoke_training/model_merge/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/model_merge/utils/normalize_weights.py",
"chars": 135,
"preview": "def normalize_weights(weights: list[float]) -> list[float]:\n total = sum(weights)\n return [weight / total for weig"
},
{
"path": "src/invoke_training/model_merge/utils/parse_model_arg.py",
"chars": 378,
"preview": "def parse_model_arg(model: str, delimiter: str = \"::\") -> tuple[str, str | None]:\n \"\"\"Parse a model argument into a m"
},
{
"path": "src/invoke_training/pipelines/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/pipelines/_experimental/sd_dpo_lora/config.py",
"chars": 11403,
"preview": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field, model_validator\n\nfrom invoke_training.config.b"
},
{
"path": "src/invoke_training/pipelines/_experimental/sd_dpo_lora/train.py",
"chars": 27220,
"preview": "import copy\nimport itertools\nimport json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\nfrom pathlib i"
},
{
"path": "src/invoke_training/pipelines/callbacks.py",
"chars": 3289,
"preview": "from abc import ABC\nfrom enum import Enum\n\n\nclass ModelType(Enum):\n # At first glance, it feels like these model type"
},
{
"path": "src/invoke_training/pipelines/flux/lora/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/pipelines/flux/lora/config.py",
"chars": 10288,
"preview": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field\n\nfrom invoke_training._shared.flux.lora_checkpo"
},
{
"path": "src/invoke_training/pipelines/flux/lora/train.py",
"chars": 30814,
"preview": "import itertools\nimport json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\nfrom pathlib import Path\nf"
},
{
"path": "src/invoke_training/pipelines/invoke_train.py",
"chars": 2288,
"preview": "import os\n\nfrom invoke_training.config.pipeline_config import PipelineConfig\nfrom invoke_training.pipelines._experimenta"
},
{
"path": "src/invoke_training/pipelines/stable_diffusion/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/pipelines/stable_diffusion/lora/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/pipelines/stable_diffusion/lora/config.py",
"chars": 10750,
"preview": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field, model_validator\n\nfrom invoke_training._shared."
},
{
"path": "src/invoke_training/pipelines/stable_diffusion/lora/train.py",
"chars": 29952,
"preview": "import itertools\nimport json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\nfrom pathlib import Path\nf"
},
{
"path": "src/invoke_training/pipelines/stable_diffusion/textual_inversion/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/pipelines/stable_diffusion/textual_inversion/config.py",
"chars": 9523,
"preview": "from typing import Literal\n\nfrom pydantic import model_validator\n\nfrom invoke_training.config.base_pipeline_config impor"
},
{
"path": "src/invoke_training/pipelines/stable_diffusion/textual_inversion/train.py",
"chars": 19996,
"preview": "import json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\n\nimport torch\nfrom accelerate import Accele"
},
{
"path": "src/invoke_training/pipelines/stable_diffusion_xl/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/pipelines/stable_diffusion_xl/finetune/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/pipelines/stable_diffusion_xl/finetune/config.py",
"chars": 8491,
"preview": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field, model_validator\n\nfrom invoke_training.config.b"
},
{
"path": "src/invoke_training/pipelines/stable_diffusion_xl/finetune/train.py",
"chars": 20254,
"preview": "import itertools\nimport json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\nfrom typing import Literal"
},
{
"path": "src/invoke_training/pipelines/stable_diffusion_xl/lora/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/pipelines/stable_diffusion_xl/lora/config.py",
"chars": 11150,
"preview": "from typing import Annotated, Literal, Union\n\nfrom pydantic import Field, model_validator\n\nfrom invoke_training._shared."
},
{
"path": "src/invoke_training/pipelines/stable_diffusion_xl/lora/train.py",
"chars": 34092,
"preview": "import itertools\nimport json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\nfrom pathlib import Path\nf"
},
{
"path": "src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/config.py",
"chars": 11458,
"preview": "from typing import Literal\n\nfrom pydantic import model_validator\n\nfrom invoke_training.config.base_pipeline_config impor"
},
{
"path": "src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/train.py",
"chars": 25664,
"preview": "import itertools\nimport json\nimport logging\nimport math\nimport os\nimport time\nfrom pathlib import Path\nfrom typing impor"
},
{
"path": "src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/config.py",
"chars": 9751,
"preview": "from typing import Literal\n\nfrom pydantic import model_validator\n\nfrom invoke_training.config.base_pipeline_config impor"
},
{
"path": "src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/train.py",
"chars": 22955,
"preview": "import json\nimport logging\nimport math\nimport os\nimport tempfile\nimport time\n\nimport torch\nimport torch.utils.data\nfrom "
},
{
"path": "src/invoke_training/sample_configs/_experimental/sd_dpo_lora_pickapic_1x24gb.yaml",
"chars": 1141,
"preview": "# Training mode: Direct Preference Optimization LoRA Training\n# Dataset: A small subset of the pickapic_v2 dataset.\n# Ba"
},
{
"path": "src/invoke_training/sample_configs/_experimental/sd_dpo_lora_refinement_pokemon_1x24gb.yaml",
"chars": 1014,
"preview": "# Training mode: Direct Preference Optimization LoRA Training\n# Base model: SD 1.5\n# GPU: 1 x 24GB\n\ntype: S"
},
{
"path": "src/invoke_training/sample_configs/flux_lora_1x40gb.yaml",
"chars": 1339,
"preview": "# Training mode: LoRA\n# Base model: Flux.1-dev\n# Dataset: Bruce the Gnome\n# GPU: 1 x 40GB\n\ntype: FLUX"
},
{
"path": "src/invoke_training/sample_configs/sd_lora_baroque_1x8gb.yaml",
"chars": 1567,
"preview": "# Training mode: Finetuning with LoRA\n# Base model: SD 1.5\n# Dataset: https://huggingface.co/datasets/InvokeAI/"
},
{
"path": "src/invoke_training/sample_configs/sd_textual_inversion_gnome_1x8gb.yaml",
"chars": 1185,
"preview": "# Training mode: Textual Inversion\n# Base model: SD v1\n# GPU: 1 x 24GB\n\ntype: SD_TEXTUAL_INVERSION\nseed: 1\n"
},
{
"path": "src/invoke_training/sample_configs/sdxl_finetune_baroque_1x24gb.yaml",
"chars": 1736,
"preview": "# Training mode: Full Finetuning\n# Base model: SDXL\n# Dataset: https://huggingface.co/datasets/InvokeAI/nga-bar"
},
{
"path": "src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml",
"chars": 1424,
"preview": "# Training mode: Full finetune\n# Base model: SDXL\n# Dataset: Robocats\n# GPU: 1 x 24GB\n\ntype: SDXL_FIN"
},
{
"path": "src/invoke_training/sample_configs/sdxl_lora_and_ti_gnome_1x24gb.yaml",
"chars": 1181,
"preview": "# Training mode: Finetuning with LoRA and Textual Inversion\n# Base model: SDXL 1.0\n# GPU: 1 x 24GB\n\ntype: S"
},
{
"path": "src/invoke_training/sample_configs/sdxl_lora_baroque_1x24gb.yaml",
"chars": 1578,
"preview": "# Training mode: Finetuning with LoRA\n# Base model: SDXL 1.0\n# Dataset: https://huggingface.co/datasets/InvokeA"
},
{
"path": "src/invoke_training/sample_configs/sdxl_lora_baroque_1x8gb.yaml",
"chars": 1890,
"preview": "# Training mode: Finetuning with LoRA\n# Base model: SDXL 1.0\n# Dataset: https://huggingface.co/datasets/InvokeA"
},
{
"path": "src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml",
"chars": 1107,
"preview": "# Training mode: LoRA with masks\n# Base model: SDXL 1.0\n# Dataset: Bruce the Gnome\n# GPU: 1 x 24GB\n\nt"
},
{
"path": "src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml",
"chars": 1122,
"preview": "# Training mode: Textual Inversion\n# Base model: SDXL\n# GPU: 1 x 24GB\n\ntype: SDXL_TEXTUAL_INVERSION\nseed: 1"
},
{
"path": "src/invoke_training/sample_configs/sdxl_textual_inversion_masks_gnome_1x24gb.yaml",
"chars": 1173,
"preview": "# Training mode: Textual Inversion with Masks\n# Base model: SDXL\n# GPU: 1 x 24GB\n\ntype: SDXL_TEXTUAL_INVERS"
},
{
"path": "src/invoke_training/scripts/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py",
"chars": 4023,
"preview": "import argparse\nimport json\nfrom pathlib import Path\n\nimport torch\nimport torch.utils.data\nfrom PIL import Image\nfrom tq"
},
{
"path": "src/invoke_training/scripts/_experimental/masks/clipseg.py",
"chars": 2228,
"preview": "import torch\nfrom PIL import Image\nfrom transformers import AutoProcessor, CLIPSegForImageSegmentation, CLIPSegProcessor"
},
{
"path": "src/invoke_training/scripts/_experimental/masks/generate_masks.py",
"chars": 3022,
"preview": "import argparse\nfrom pathlib import Path\n\nimport torch\nimport torch.utils.data\nfrom tqdm import tqdm\n\nfrom invoke_traini"
},
{
"path": "src/invoke_training/scripts/_experimental/masks/generate_masks_for_jsonl_dataset.py",
"chars": 5195,
"preview": "import argparse\nfrom pathlib import Path\n\nimport torch\nimport torch.utils.data\nfrom tqdm import tqdm\n\nfrom invoke_traini"
},
{
"path": "src/invoke_training/scripts/_experimental/rank_images.py",
"chars": 3899,
"preview": "import argparse\nimport os\nimport time\nfrom pathlib import Path\nfrom typing import Literal\n\nimport gradio as gr\nimport ya"
},
{
"path": "src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py",
"chars": 1475,
"preview": "import argparse\nfrom pathlib import Path\n\nimport torch\n\nfrom invoke_training._shared.stable_diffusion.lora_checkpoint_ut"
},
{
"path": "src/invoke_training/scripts/invoke_generate_images.py",
"chars": 4689,
"preview": "import argparse\nfrom pathlib import Path\n\nfrom invoke_training._shared.stable_diffusion.model_loading_utils import Pipel"
},
{
"path": "src/invoke_training/scripts/invoke_train.py",
"chars": 846,
"preview": "import argparse\nfrom pathlib import Path\n\nimport yaml\nfrom pydantic import TypeAdapter\n\nfrom invoke_training.config.pipe"
},
{
"path": "src/invoke_training/scripts/invoke_train_ui.py",
"chars": 567,
"preview": "import argparse\n\nimport uvicorn\n\nfrom invoke_training.ui.app import build_app\n\n\ndef main():\n parser = argparse.Argume"
},
{
"path": "src/invoke_training/scripts/invoke_visualize_data_loading.py",
"chars": 4488,
"preview": "import argparse\nimport os\nimport time\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport yaml\nfrom PIL imp"
},
{
"path": "src/invoke_training/scripts/utils/image_dir_dataset.py",
"chars": 1756,
"preview": "import os\nimport typing\n\nimport torch\nfrom PIL import Image\n\n\nclass ImageDirDataset(torch.utils.data.Dataset):\n \"\"\"A "
},
{
"path": "src/invoke_training/ui/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/ui/app.py",
"chars": 875,
"preview": "from pathlib import Path\n\nimport gradio as gr\nfrom fastapi import FastAPI\nfrom fastapi.responses import FileResponse\nfro"
},
{
"path": "src/invoke_training/ui/config_groups/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/invoke_training/ui/config_groups/aspect_ratio_bucket_config_group.py",
"chars": 2725,
"preview": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.data.data_loader_config import AspectRatioBucke"
},
{
"path": "src/invoke_training/ui/config_groups/base_pipeline_config_group.py",
"chars": 6243,
"preview": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.base_pipeline_config import BasePipelineConfig\n"
},
{
"path": "src/invoke_training/ui/config_groups/dataset_config_group.py",
"chars": 12854,
"preview": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.data.dataset_config import (\n HFHubImageCapt"
},
{
"path": "src/invoke_training/ui/config_groups/flux_lora_config_group.py",
"chars": 20273,
"preview": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.flux.lora.config import FluxLoraConfig\nfrom invoke_tr"
},
{
"path": "src/invoke_training/ui/config_groups/image_caption_sd_data_loader_config_group.py",
"chars": 5961,
"preview": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.data.data_loader_config import ImageCaptionSDDa"
},
{
"path": "src/invoke_training/ui/config_groups/optimizer_config_group.py",
"chars": 6851,
"preview": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.optimizer.optimizer_config import AdamOptimizer"
},
{
"path": "src/invoke_training/ui/config_groups/sd_lora_config_group.py",
"chars": 14076,
"preview": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig\nfrom"
},
{
"path": "src/invoke_training/ui/config_groups/sd_textual_inversion_config_group.py",
"chars": 13704,
"preview": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTe"
},
{
"path": "src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py",
"chars": 13577,
"preview": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetu"
},
{
"path": "src/invoke_training/ui/config_groups/sdxl_lora_and_textual_inversion_config_group.py",
"chars": 17946,
"preview": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config"
},
{
"path": "src/invoke_training/ui/config_groups/sdxl_lora_config_group.py",
"chars": 14513,
"preview": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig"
},
{
"path": "src/invoke_training/ui/config_groups/sdxl_textual_inversion_config_group.py",
"chars": 14131,
"preview": "import typing\n\nimport gradio as gr\n\nfrom invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import S"
},
{
"path": "src/invoke_training/ui/config_groups/textual_inversion_sd_data_loader_config_group.py",
"chars": 7577,
"preview": "from typing import Any\n\nimport gradio as gr\n\nfrom invoke_training.config.data.data_loader_config import (\n TextualInv"
},
{
"path": "src/invoke_training/ui/config_groups/ui_config_element.py",
"chars": 1662,
"preview": "from typing import Any\n\nimport gradio as gr\n\n\nclass UIConfigElement:\n \"\"\"A base class for UI blocks that represent a "
},
{
"path": "src/invoke_training/ui/gradio_blocks/header.py",
"chars": 598,
"preview": "import gradio as gr\n\nfrom invoke_training.ui.utils.utils import get_assets_dir_path\n\n\nclass Header:\n def __init__(sel"
},
{
"path": "src/invoke_training/ui/gradio_blocks/pipeline_tab.py",
"chars": 6915,
"preview": "import typing\n\nimport gradio as gr\nimport yaml\n\nfrom invoke_training.config.pipeline_config import PipelineConfig\nfrom i"
},
{
"path": "src/invoke_training/ui/index.html",
"chars": 2058,
"preview": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n <meta charset=\"UTF-8\">\n <meta name=\"viewport\" content=\"width=device-width"
},
{
"path": "src/invoke_training/ui/pages/data_page.py",
"chars": 14842,
"preview": "from pathlib import Path\n\nimport gradio as gr\nfrom PIL import Image\n\nfrom invoke_training._shared.data.datasets.image_ca"
},
{
"path": "src/invoke_training/ui/pages/training_page.py",
"chars": 8100,
"preview": "import os\nimport subprocess\nimport tempfile\nimport time\n\nimport gradio as gr\nimport yaml\n\nfrom invoke_training.config.pi"
},
{
"path": "src/invoke_training/ui/utils/prompts.py",
"chars": 3059,
"preview": "NEGATIVE_PROMPT_DELIMITER = \"[NEG]\"\n\n\ndef split_pos_neg_prompts(prompt: str) -> tuple[str, str]:\n \"\"\"Split a prompt c"
},
{
"path": "src/invoke_training/ui/utils/utils.py",
"chars": 1016,
"preview": "import typing\nfrom pathlib import Path\n\nimport yaml\nfrom pydantic import TypeAdapter\n\nfrom invoke_training.config.pipeli"
},
{
"path": "tests/invoke_training/_shared/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "tests/invoke_training/_shared/checkpoints/test_checkpoint_tracker.py",
"chars": 4020,
"preview": "import os\nimport tempfile\nfrom pathlib import Path\n\nimport pytest\n\nfrom invoke_training._shared.checkpoints.checkpoint_t"
},
{
"path": "tests/invoke_training/_shared/checkpoints/test_serialization.py",
"chars": 1217,
"preview": "import os\nimport tempfile\n\nimport pytest\nimport torch\n\nfrom invoke_training._shared.checkpoints.serialization import (\n "
},
{
"path": "tests/invoke_training/_shared/data/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "tests/invoke_training/_shared/data/data_loaders/__init__.py",
"chars": 0,
"preview": ""
}
]
// ... and 46 more files (download for full content)
About this extraction
This page contains the full source code of the invoke-ai/invoke-training GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 246 files (930.6 KB), approximately 219.8k tokens, and a symbol index with 581 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.