Repository: invoke-ai/invoke-training
Branch: main
Commit: 363f83cdb5e6
Files: 246
Total size: 930.6 KB
Directory structure:
gitextract_5kosyax7/
├── .github/
│ └── workflows/
│ ├── deploy.yaml
│ └── test.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── docs/
│ ├── contributing/
│ │ ├── development_environment.md
│ │ ├── directory_structure.md
│ │ ├── documentation.md
│ │ └── tests.md
│ ├── get-started/
│ │ ├── installation.md
│ │ └── quick-start.md
│ ├── guides/
│ │ ├── dataset_formats.md
│ │ ├── model_merge.md
│ │ └── stable_diffusion/
│ │ ├── dpo_lora_sd.md
│ │ ├── gnome_lora_masks_sdxl.md
│ │ ├── robocats_finetune_sdxl.md
│ │ └── textual_inversion_sdxl.md
│ ├── index.md
│ ├── reference/
│ │ └── config/
│ │ ├── index.md
│ │ ├── pipelines/
│ │ │ ├── sd_lora.md
│ │ │ ├── sd_textual_inversion.md
│ │ │ ├── sdxl_finetune.md
│ │ │ ├── sdxl_lora.md
│ │ │ ├── sdxl_lora_and_textual_inversion.md
│ │ │ └── sdxl_textual_inversion.md
│ │ └── shared/
│ │ ├── data/
│ │ │ ├── data_loader_config.md
│ │ │ └── dataset_config.md
│ │ └── optimizer_config.md
│ └── templates/
│ └── python/
│ └── material/
│ └── labels.html
├── mkdocs.yml
├── pyproject.toml
├── sample_data/
│ └── bruce_the_gnome/
│ └── data.jsonl
├── src/
│ └── invoke_training/
│ ├── __init__.py
│ ├── _shared/
│ │ ├── __init__.py
│ │ ├── accelerator/
│ │ │ ├── __init__.py
│ │ │ └── accelerator_utils.py
│ │ ├── checkpoints/
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint_tracker.py
│ │ │ ├── lora_checkpoint_utils.py
│ │ │ └── serialization.py
│ │ ├── data/
│ │ │ ├── ARCHITECTURE.md
│ │ │ ├── __init__.py
│ │ │ ├── data_loaders/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── dreambooth_sd_dataloader.py
│ │ │ │ ├── image_caption_flux_dataloader.py
│ │ │ │ ├── image_caption_sd_dataloader.py
│ │ │ │ ├── image_pair_preference_sd_dataloader.py
│ │ │ │ └── textual_inversion_sd_dataloader.py
│ │ │ ├── datasets/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── build_dataset.py
│ │ │ │ ├── hf_image_caption_dataset.py
│ │ │ │ ├── hf_image_pair_preference_dataset.py
│ │ │ │ ├── image_caption_dir_dataset.py
│ │ │ │ ├── image_caption_jsonl_dataset.py
│ │ │ │ ├── image_dir_dataset.py
│ │ │ │ ├── image_pair_preference_dataset.py
│ │ │ │ └── transform_dataset.py
│ │ │ ├── samplers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── aspect_ratio_bucket_batch_sampler.py
│ │ │ │ ├── batch_offset_sampler.py
│ │ │ │ ├── concat_sampler.py
│ │ │ │ ├── interleaved_sampler.py
│ │ │ │ └── offset_sampler.py
│ │ │ ├── transforms/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── caption_prefix_transform.py
│ │ │ │ ├── concat_fields_transform.py
│ │ │ │ ├── constant_field_transform.py
│ │ │ │ ├── drop_field_transform.py
│ │ │ │ ├── flux_image_transform.py
│ │ │ │ ├── load_cache_transform.py
│ │ │ │ ├── sd_image_transform.py
│ │ │ │ ├── shuffle_caption_transform.py
│ │ │ │ ├── template_caption_transform.py
│ │ │ │ └── tensor_disk_cache.py
│ │ │ └── utils/
│ │ │ ├── __init__.py
│ │ │ ├── aspect_ratio_bucket_manager.py
│ │ │ ├── resize.py
│ │ │ └── resolution.py
│ │ ├── flux/
│ │ │ ├── encoding_utils.py
│ │ │ ├── lora_checkpoint_utils.py
│ │ │ ├── model_loading_utils.py
│ │ │ └── validation.py
│ │ ├── optimizer/
│ │ │ ├── __init__.py
│ │ │ └── optimizer_utils.py
│ │ ├── stable_diffusion/
│ │ │ ├── __init__.py
│ │ │ ├── base_model_version.py
│ │ │ ├── checkpoint_utils.py
│ │ │ ├── lora_checkpoint_utils.py
│ │ │ ├── min_snr_weighting.py
│ │ │ ├── model_loading_utils.py
│ │ │ ├── textual_inversion.py
│ │ │ ├── tokenize_captions.py
│ │ │ └── validation.py
│ │ ├── tools/
│ │ │ ├── __init__.py
│ │ │ └── generate_images.py
│ │ └── utils/
│ │ ├── import_xformers.py
│ │ └── jsonl.py
│ ├── config/
│ │ ├── __init__.py
│ │ ├── base_pipeline_config.py
│ │ ├── config_base_model.py
│ │ ├── data/
│ │ │ ├── __init__.py
│ │ │ ├── data_loader_config.py
│ │ │ └── dataset_config.py
│ │ ├── optimizer/
│ │ │ ├── __init__.py
│ │ │ └── optimizer_config.py
│ │ └── pipeline_config.py
│ ├── model_merge/
│ │ ├── __init__.py
│ │ ├── extract_lora.py
│ │ ├── merge_models.py
│ │ ├── merge_tasks_to_base.py
│ │ ├── scripts/
│ │ │ ├── extract_lora_from_model_diff.py
│ │ │ ├── merge_lora_into_model.py
│ │ │ ├── merge_models.py
│ │ │ └── merge_task_models_to_base_model.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── normalize_weights.py
│ │ └── parse_model_arg.py
│ ├── pipelines/
│ │ ├── __init__.py
│ │ ├── _experimental/
│ │ │ └── sd_dpo_lora/
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ ├── callbacks.py
│ │ ├── flux/
│ │ │ └── lora/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ ├── invoke_train.py
│ │ ├── stable_diffusion/
│ │ │ ├── __init__.py
│ │ │ ├── lora/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── config.py
│ │ │ │ └── train.py
│ │ │ └── textual_inversion/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ └── stable_diffusion_xl/
│ │ ├── __init__.py
│ │ ├── finetune/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ ├── lora/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ ├── lora_and_textual_inversion/
│ │ │ ├── __init__.py
│ │ │ ├── config.py
│ │ │ └── train.py
│ │ └── textual_inversion/
│ │ ├── __init__.py
│ │ ├── config.py
│ │ └── train.py
│ ├── sample_configs/
│ │ ├── _experimental/
│ │ │ ├── sd_dpo_lora_pickapic_1x24gb.yaml
│ │ │ └── sd_dpo_lora_refinement_pokemon_1x24gb.yaml
│ │ ├── flux_lora_1x40gb.yaml
│ │ ├── sd_lora_baroque_1x8gb.yaml
│ │ ├── sd_textual_inversion_gnome_1x8gb.yaml
│ │ ├── sdxl_finetune_baroque_1x24gb.yaml
│ │ ├── sdxl_finetune_robocats_1x24gb.yaml
│ │ ├── sdxl_lora_and_ti_gnome_1x24gb.yaml
│ │ ├── sdxl_lora_baroque_1x24gb.yaml
│ │ ├── sdxl_lora_baroque_1x8gb.yaml
│ │ ├── sdxl_lora_masks_gnome_1x24gb.yaml
│ │ ├── sdxl_textual_inversion_gnome_1x24gb.yaml
│ │ └── sdxl_textual_inversion_masks_gnome_1x24gb.yaml
│ ├── scripts/
│ │ ├── __init__.py
│ │ ├── _experimental/
│ │ │ ├── auto_caption/
│ │ │ │ └── auto_caption_images.py
│ │ │ ├── masks/
│ │ │ │ ├── clipseg.py
│ │ │ │ ├── generate_masks.py
│ │ │ │ └── generate_masks_for_jsonl_dataset.py
│ │ │ └── rank_images.py
│ │ ├── convert_sd_lora_to_kohya_format.py
│ │ ├── invoke_generate_images.py
│ │ ├── invoke_train.py
│ │ ├── invoke_train_ui.py
│ │ ├── invoke_visualize_data_loading.py
│ │ └── utils/
│ │ └── image_dir_dataset.py
│ └── ui/
│ ├── __init__.py
│ ├── app.py
│ ├── config_groups/
│ │ ├── __init__.py
│ │ ├── aspect_ratio_bucket_config_group.py
│ │ ├── base_pipeline_config_group.py
│ │ ├── dataset_config_group.py
│ │ ├── flux_lora_config_group.py
│ │ ├── image_caption_sd_data_loader_config_group.py
│ │ ├── optimizer_config_group.py
│ │ ├── sd_lora_config_group.py
│ │ ├── sd_textual_inversion_config_group.py
│ │ ├── sdxl_finetune_config_group.py
│ │ ├── sdxl_lora_and_textual_inversion_config_group.py
│ │ ├── sdxl_lora_config_group.py
│ │ ├── sdxl_textual_inversion_config_group.py
│ │ ├── textual_inversion_sd_data_loader_config_group.py
│ │ └── ui_config_element.py
│ ├── gradio_blocks/
│ │ ├── header.py
│ │ └── pipeline_tab.py
│ ├── index.html
│ ├── pages/
│ │ ├── data_page.py
│ │ └── training_page.py
│ └── utils/
│ ├── prompts.py
│ └── utils.py
└── tests/
└── invoke_training/
├── _shared/
│ ├── __init__.py
│ ├── checkpoints/
│ │ ├── test_checkpoint_tracker.py
│ │ └── test_serialization.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── data_loaders/
│ │ │ ├── __init__.py
│ │ │ ├── test_dreambooth_sd_dataloader.py
│ │ │ ├── test_image_caption_sd_dataloader.py
│ │ │ ├── test_image_pair_preference_sd_dataloader.py
│ │ │ └── test_textual_inversion_sd_dataloader.py
│ │ ├── dataset_fixtures.py
│ │ ├── datasets/
│ │ │ ├── __init__.py
│ │ │ ├── test_hf_image_caption_dataset.py
│ │ │ ├── test_hf_image_pair_preference_dataset.py
│ │ │ ├── test_image_caption_dir_dataset.py
│ │ │ ├── test_image_caption_jsonl_dataset.py
│ │ │ ├── test_image_dir_dataset.py
│ │ │ ├── test_image_pair_preference_dataset.py
│ │ │ └── test_transform_dataset.py
│ │ ├── samplers/
│ │ │ ├── __init__.py
│ │ │ ├── test_aspect_ratio_bucket_batch_sampler.py
│ │ │ ├── test_batch_offset_sampler.py
│ │ │ ├── test_concat_sampler.py
│ │ │ ├── test_interleaved_sampler.py
│ │ │ └── test_offset_sampler.py
│ │ ├── transforms/
│ │ │ ├── __init__.py
│ │ │ ├── test_caption_prefix_transform.py
│ │ │ ├── test_concat_fields_transform.py
│ │ │ ├── test_constant_field_transform.py
│ │ │ ├── test_drop_field_transform.py
│ │ │ ├── test_load_cache_transform.py
│ │ │ ├── test_sd_image_transform.py
│ │ │ ├── test_shuffle_caption_transform.py
│ │ │ ├── test_template_caption_transform.py
│ │ │ └── test_tensor_disk_cache.py
│ │ └── utils/
│ │ ├── __init__.py
│ │ ├── test_aspect_ratio_bucket_manager.py
│ │ ├── test_resize.py
│ │ └── test_resolution.py
│ ├── stable_diffusion/
│ │ ├── __init__.py
│ │ ├── test_base_model_version.py
│ │ ├── test_lora_checkpoint_utils.py
│ │ ├── test_model_loading_utils.py
│ │ ├── test_textual_inversion.py
│ │ └── ti_embedding_checkpoint_fixture.py
│ └── utils/
│ └── test_jsonl.py
├── config/
│ └── pipelines/
│ └── test_pipeline_config.py
├── model_merge/
│ ├── __init__.py
│ ├── test_merge_models.py
│ ├── test_merge_tasks_to_base.py
│ └── utils.py
└── ui/
└── utils/
└── test_prompts.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/deploy.yaml
================================================
name: Deploy invoke-training docs
on:
push:
branches:
- main
permissions:
contents: write
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Configure Git Credentials
run: |
git config user.name github-actions[bot]
git config user.email 41898282+github-actions[bot]@users.noreply.github.com
- uses: actions/setup-python@v4
with:
python-version: "3.10"
cache: pip
cache-dependency-path: pyproject.toml
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install .[test]
- run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV
- uses: actions/cache@v3
with:
key: mkdocs-material-${{ env.cache_id }}
path: .cache
restore-keys: |
mkdocs-material-
- run: mkdocs gh-deploy --force
================================================
FILE: .github/workflows/test.yaml
================================================
name: Test invoke-training
on:
push:
branches:
- main
pull_request:
workflow_dispatch:
jobs:
build:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.12"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: pip
cache-dependency-path: pyproject.toml
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install .[test]
- name: Ruff lint
run: |
ruff check --output-format=github .
- name: Ruff format
run: |
ruff format --check .
- name: Test with pytest
run: |
pytest tests --junitxml=junit/test-results-${{ matrix.python-version }}.xml -m "not cuda and not loads_model"
- name: Upload pytest test results
uses: actions/upload-artifact@v4
with:
name: pytest-results-${{ matrix.python-version }}
path: junit/test-results-${{ matrix.python-version }}.xml
# Use always() to always run this step to publish test results when there are test failures.
if: ${{ always() }}
================================================
FILE: .gitignore
================================================
/output/
/test_configs/
/data/
# pyenv
.python-version
# VSCode
.vscode/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
junit/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.aider*
================================================
FILE: .pre-commit-config.yaml
================================================
# See https://pre-commit.com/ for usage and config.
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.7
hooks:
# Run the linter.
- id: ruff
# Run the formatter.
- id: ruff-format
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# invoke-training
A library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion, etc.) that can be used in [InvokeAI](https://github.com/invoke-ai/InvokeAI).
> [!WARNING]
> `invoke-training` is still under active development, and breaking changes are likely. Full backwards compatibility will not be guaranteed until v1.0.0.
> In the meantime, I recommend pinning to a specific commit hash.
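## Documentation
Read the full documentation at [invoke-ai.github.io/invoke-training](https://invoke-ai.github.io/invoke-training/).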
## Training Modes
- Stable Diffusion
- LoRA
- DreamBooth LoRA
- Textual Inversion
- Stable Diffusion XL
- Full finetuning
- LoRA
- DreamBooth LoRA
- Textual Inversion
- LoRA and Textual Inversion
More training modes coming soon!
## Installation
See the [Installation](https://invoke-ai.github.io/invoke-training/get-started/installation/) section of the documentation.
## Quick Start
`invoke-training` pipelines can be configured and launched from either the CLI or the GUI.
### CLI
Run training via the CLI with type-checked YAML configuration files for maximum control:
```bash
invoke-train --cfg-file src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml
```
### GUI
Run training via the GUI for a simpler starting point.
```bash
invoke-train-ui
# Or, you can optionally override the default host and port:
invoke-train-ui --host 0.0.0.0 --port 1234
```
## Features
Training progress can be monitored with [Tensorboard](https://www.tensorflow.org/tensorboard):

_Validation images in the Tensorboard UI._
All trained models are compatible with InvokeAI:

_Example image generated with the prompt "A cute yoda pokemon creature." and a trained Pokemon LoRA._
## Contributing
Contributors are welcome. For developer guidance, see the [Contributing](https://invoke-ai.github.io/invoke-training/contributing/development_environment/) section of the documentation.
================================================
FILE: docs/contributing/development_environment.md
================================================
# Development Environment Setup
See the [developer installation instructions](../get-started/installation.md#developer-installation).
================================================
FILE: docs/contributing/directory_structure.md
================================================
# Directory Structure
```bash
invoke-training/
├── README.md
├── docs/
├── src/
│ └── invoke_training/
│ ├── _shared/ # Utilities shared across multiple pipelines. High unit test coverage.
│ ├── config/ # Config structures shared by multiple pipelines.
│ ├── pipelines/ # Each pipeline is isolated in its own directory with a train.py and a config.py.
│ │ ├── stable_diffusion/
│ │ │ ├── lora/
│ │ │ │ ├── config.py
│ │ │ │ └── train.py
│ │ │ └── textual_inversion/
│ │ │ └── ...
│ │ ├── stable_diffusion_xl/
│ │ └── ...
│ └── scripts/ # Main entrypoints.
└── tests/ # Mirrors src/ directory.
```
================================================
FILE: docs/contributing/documentation.md
================================================
# Documentation
The documentation site is generated using [mkdocs](https://www.mkdocs.org/) and [mkdocstrings-python](https://mkdocstrings.github.io/python/).
To view your documentation changes locally, run `mkdocs serve`.
================================================
FILE: docs/contributing/tests.md
================================================
# Tests
Run all unit tests with:
```bash
pytest tests/
```
There are some test 'markers' defined in [pyproject.toml](https://github.com/invoke-ai/invoke-training/blob/main/pyproject.toml) that can be used to skip some tests. For example, the following command skips tests that require a GPU or require downloading model weights:
```bash
pytest tests/ -m "not cuda and not loads_model"
```
================================================
FILE: docs/get-started/installation.md
================================================
# Installation
## Requirements
1. Python 3.10, 3.11 and 3.12 are currently supported. Check your Python version by running `python -V`.
2. An NVIDIA GPU with >= 8 GB VRAM is recommended for model training.
## Basic Installation
0. Open your terminal and navigate to the directory where you want to clone the `invoke-training` repo.
1. Clone the repo:
```bash
git clone https://github.com/invoke-ai/invoke-training.git
```
2. Create and activate a python [virtual environment](https://docs.python.org/3/library/venv.html#creating-virtual-environments). This creates an isolated environment for `invoke-training` and its dependencies that won't interfere with other python environments on your system, including any installations of [InvokeAI](https://www.github.com/invoke-ai/invokeai).
```bash
# Navigate to the invoke-training directory.
cd invoke-training
# Create a new virtual environment named `invoketraining`.
python -m venv invoketraining
# Activate the new virtual environment.
# On Windows:
.\invoketraining\Scripts\activate
# On MacOS / Linux:
source invoketraining/bin/activate
```
3. Install `invoke-training` and its dependencies. Run the appropriate install command for your system.
```bash
# A recent version of pip is required, so first upgrade pip:
python -m pip install --upgrade pip
# Install - Windows or Linux with an NVIDIA GPU:
pip install ".[test]" --extra-index-url https://download.pytorch.org/whl/cu126
# Install - Linux with no GPU:
pip install ".[test]" --extra-index-url https://download.pytorch.org/whl/cpu
# Install - All other systems:
pip install ".[test]"
```
Each time you run `invoke-training` in the future, first activate the virtual environment you created during installation, using the same activation command shown above.
## Developer Installation
Consider forking the repo if you plan to contribute code changes.
Follow the above installation instructions, cloning your fork instead of this repo if you made a fork.
Next, we suggest setting up the repo's pre-commit hooks to automatically format and lint your contributions:
1. (_Optional_) Install the pre-commit hooks: `pre-commit install`. This will run static analysis tools (ruff) on `git commit`.
2. (_Optional_) Set up `ruff` in your IDE of choice.
================================================
FILE: docs/get-started/quick-start.md
================================================
# Quick Start
`invoke-training` has both a GUI and a CLI (for advanced users). The instructions for getting started with both options can be found on this page.
There is also a video introduction to `invoke-training`:
## Quick Start - GUI
### 1. Installation
Follow the [`invoke-training` installation instructions](./installation.md).
### 2. Launch the GUI
Activate the virtual environment you created during installation, using the same command you used during installation.
You'll need to do this every time you run `invoke-training`.
```bash
# From the invoke-training directory:
invoke-train-ui
# Or, you can optionally override the default host and port:
invoke-train-ui --host 0.0.0.0 --port 1234
```
Access the GUI in your browser at the URL printed to the console.
### 3. Configure the training job
Select the desired training pipeline type in the top-level tab.
For this tutorial, we don't need to change any of the configuration values. The preset configuration should work well.
### 4. Generate the YAML configuration
Click on 'Generate Config' to generate a YAML configuration file. This YAML configuration file can be used to launch the training job from the CLI, if desired.
### 5. Start training
Click on the 'Start Training' button and check your terminal for progress logs.
### 6. Monitor training
Monitor the training process with Tensorboard by running `tensorboard --logdir output/` and visiting [localhost:6006](http://localhost:6006) in your browser. Here you can see generated validation images throughout the training process.

_Validation images in the Tensorboard UI._
### 7. InvokeAI
Select a checkpoint based on the quality of the generated images.
If you haven't already, set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.
Copy your selected LoRA checkpoint into your `${INVOKEAI_ROOT}/autoimport/lora` directory. For example:
```bash
# Note: You will have to replace the timestamp in the checkpoint path.
cp output/1691088769.5694647/checkpoint_epoch-00000002.safetensors ${INVOKEAI_ROOT}/autoimport/lora/pokemon_epoch-00000002.safetensors
```
You can now use your trained Pokemon LoRA in the InvokeAI UI! 🎉

_Example image generated with the prompt "A cute yoda pokemon creature." and Pokemon LoRA._
## Quick Start - CLI
### 1. Installation
Follow the [`invoke-training` installation instructions](./installation.md).
### 2. Training
Activate the virtual environment you created during installation, using the same command you used during installation.
You'll need to do this every time you run `invoke-training`.
See the [Textual Inversion - SDXL](../guides/stable_diffusion/textual_inversion_sdxl.md) tutorial for instructions on how to train a model via the CLI.
================================================
FILE: docs/guides/dataset_formats.md
================================================
# Dataset Formats
`invoke-training` supports the following dataset formats:
- `IMAGE_CAPTION_JSONL_DATASET`: A local image-caption dataset described by a single `.jsonl` file.
- `IMAGE_CAPTION_DIR_DATASET`: A local directory of images with associated `.txt` caption files.
- `IMAGE_DIR_DATASET`: A local directory of images (without captions).
- `HF_HUB_IMAGE_CAPTION_DATASET`: A Hugging Face Hub dataset containing images and captions.
See the documentation for a particular training pipeline to see which dataset formats it supports.
The following sections explain each of these formats in more detail.
## `IMAGE_CAPTION_JSONL_DATASET`
Config documentation: [ImageCaptionJsonlDatasetConfig][invoke_training.config.data.dataset_config.ImageCaptionJsonlDatasetConfig]
An `IMAGE_CAPTION_JSONL_DATASET` consists of a single `.jsonl` file containing image paths and associated captions.
Sample directory structure:
```bash
my_custom_dataset/
├── data.jsonl
└── train/
├── 0001.png
├── 0002.png
├── 0003.png
└── ...
```
The contents of `data.jsonl` would be:
```json
{"file_name": "train/0001.png", "text": "This is a caption describing image 0001."}
{"file_name": "train/0002.png", "text": "This is a caption describing image 0002."}
{"file_name": "train/0003.png", "text": "This is a caption describing image 0003."}
```
The image file paths can be either absolute paths, or relative to the `.jsonl` file.
Finally, this dataset can be used with the following pipeline dataset configuration:
```yaml
type: IMAGE_CAPTION_JSONL_DATASET
jsonl_path: /path/to/my_custom_dataset/data.jsonl
image_column: file_name
caption_column: text
```
A useful characteristic of this dataset format is that a `.jsonl` file can reference an image file anywhere on the local disk. It is common to maintain multiple `.jsonl` datasets that reference some of the same images without needing multiple copies of those images on disk.
## `IMAGE_CAPTION_DIR_DATASET`
Config documentation: [ImageCaptionDirDatasetConfig][invoke_training.config.data.dataset_config.ImageCaptionDirDatasetConfig]
An `IMAGE_CAPTION_DIR_DATASET` consists of a directory of image files and corresponding `.txt` caption files of the same name.
Sample directory structure:
```bash
my_custom_dataset/
├── 0001.png
├── 0001.txt
├── 0002.jpg
├── 0002.txt
├── 0003.png
├── 0003.txt
└── ...
```
Each `.txt` file should contain a caption on the first line of the file. Here are the sample contents of `0001.txt`:
```txt title="0001.txt"
this is a caption for example 0001
```
This dataset can be used with the following pipeline dataset configuration:
```yaml
type: IMAGE_CAPTION_DIR_DATASET
dataset_dir: /path/to/my_custom_dataset
```
## `IMAGE_DIR_DATASET`
Config documentation: [ImageDirDatasetConfig][invoke_training.config.data.dataset_config.ImageDirDatasetConfig]
An `IMAGE_DIR_DATASET` consists of a single directory of images (without captions).
Sample directory structure:
```bash
my_custom_dataset/
├── 0001.png
├── 0002.jpg
├── 0003.png
└── ...
```
This dataset can be used with the following pipeline dataset configuration:
```yaml
type: IMAGE_DIR_DATASET
dataset_dir: /path/to/my_custom_dataset
```
## `HF_HUB_IMAGE_CAPTION_DATASET`
Config documentation: [HFHubImageCaptionDatasetConfig][invoke_training.config.data.dataset_config.HFHubImageCaptionDatasetConfig]
The `HF_HUB_IMAGE_CAPTION_DATASET` dataset format can be used to access public datasets on the [Hugging Face Hub](https://huggingface.co/datasets). You can filter for the `Text-to-Image` task to find relevant datasets that contain both an image column and a caption column. [lambdalabs/pokemon-blip-captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) is a popular choice if you're not sure where to start.
================================================
FILE: docs/guides/model_merge.md
================================================
# Model Merging
`invoke-training` provides utility scripts for several common model merging workflows. This page contains a summary of the available tools.
## `extract_lora_from_model_diff.py`
Extract a LoRA model that represents the difference between two base models.
Note that the extracted LoRA model is a lossy representation of the difference between the models, so some degradation in quality is expected.
For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py -h
```
## `merge_lora_into_model.py`
Merge a LoRA model into a base model to produce a new base model.
For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/merge_lora_into_model.py -h
```
## `merge_models.py`
Merge 2 or more base models to produce a single base model (using either LERP or SLERP). This is a simple merge strategy that merges all model weights in the same way.
For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/merge_models.py -h
```
## `merge_task_models_to_base_model.py`
Merge 1 or more task-specific base models into a single starting base model (using either [TIES](https://arxiv.org/abs/2306.01708) or [DARE](https://arxiv.org/abs/2311.03099)). This merge strategy aims to preserve the task-specific behaviors of the task models while making only small changes to the original base model. This approach enables multiple task models to be merged without excessive interference between them.
If you want to merge a task-specific LoRA into a base model using this strategy, first use `merge_lora_into_model.py` to produce a task-specific base model, then merge that new base model using this strategy.
For usage docs, run:
```bash
python src/invoke_training/model_merge/scripts/merge_task_models_to_base_model.py -h
```
================================================
FILE: docs/guides/stable_diffusion/dpo_lora_sd.md
================================================
# (Experimental) Diffusion DPO - SD
!!! tip "Experimental"
The Diffusion Direct Preference Optimization training pipeline is still experimental. Support may be dropped at any time.
This tutorial walks through some initial experiments around using Diffusion Direct Preference Optimization (DPO) ([paper](https://arxiv.org/abs/2311.12908)) to train Stable Diffusion LoRA models.
## Experiment 1: `pickapic_v2` LoRA Training
The Diffusion-DPO paper does full model fine-tuning on the [pickapic_v2](https://huggingface.co/datasets/yuvalkirstain/pickapic_v2) dataset, which consists of roughly 1M AI-generated image pairs with preference annotations. In this experiment, we attempt to fine-tune a Stable Diffusion LoRA model using a small subset of the pickapic_v2 dataset.
Run this experiment with the following command:
```bash
invoke-train -c src/invoke_training/sample_configs/_experimental/sd_dpo_lora_pickapic_1x24gb.yaml
```
Here is a cherry-picked example of a prompt for which this training process was clearly beneficial.
Prompt: "*A galaxy-colored figurine is floating over the sea at sunset, photorealistic*"
| Before DPO Training | After DPO Training (same seed)|
| - | - |
|  |  |
## Experiment 2: LoRA Model Refinement
As a second experiment, we attempt the following workflow:
1. Train a Stable Diffusion LoRA model on a particular style.
2. Generate pairs of images in that style with the trained LoRA model.
3. Annotate the preferred image from each pair.
4. Apply Diffusion-DPO to the preference-annotated pairs to further fine-tune the LoRA model.
Note: The steps listed below are pretty rough. They are included primarily for reference for someone looking to resume this line of work in the future.
### 1. Train a style LoRA
```bash
invoke-train -c src/invoke_training/sample_configs/sd_lora_pokemon_1x8gb.yaml
```
### 2. Generate images
Prepare ~100 relevant prompts that will be used to generate training data with the freshly-trained LoRA model. Add the prompts to a `.txt` file - one prompt per line.
Example prompts:
```txt
a cute orange pokemon character with pointy ears
a drawing of a purple fish
a cartoon blob with a smile on its face
a drawing of a snail with big eyes
...
```
```bash
# Convert the LoRA checkpoint of interest to Kohya format.
# You will have to change the path timestamps in this example command.
# TODO(ryand): This manual conversion shouldn't be necessary.
python src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py \
--src-ckpt-dir output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003/ \
--dst-ckpt-file output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003_kohya.safetensors
# Generate 2 pairs of images for each prompt.
invoke-generate-images \
-o output/pokemon_pairs \
-m runwayml/stable-diffusion-v1-5 \
-v fp16 \
-l output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003_kohya.safetensors \
--sd-version SD \
--prompt-file path/to/prompts.txt \
--set-size 2 \
--num-sets 2 \
--height 512 \
--width 512
```
### 3. Annotate the image pair preferences
Launch the gradio UI for selecting image pair preferences.
```bash
# Note: rank_images.py accepts a full training pipeline config, but only uses the dataset configuration.
python src/invoke_training/scripts/_experimental/rank_images.py -c src/invoke_training/sample_configs/_experimental/sd_dpo_lora_refinement_pokemon_1x24gb.yaml
```
After completing the pair annotations, click "Save Metadata" and move the resultant metadata file to your image data directory (e.g. `output/pokemon_pairs/metadata.jsonl`).
### 4. Run Diffusion-DPO
```bash
invoke-train -c src/invoke_training/sample_configs/_experimental/sd_dpo_lora_refinement_pokemon_1x24gb.yaml
```
================================================
FILE: docs/guides/stable_diffusion/gnome_lora_masks_sdxl.md
================================================
# LoRA with Masks - SDXL
This tutorial explains how to prepare masks for an image dataset and then use that dataset to train an SDXL LoRA model.
Masks can be used to weight regions of images in a dataset to control how much they contribute to the training process. In this tutorial we will use masks to train on a small dataset of images of Bruce the Gnome (4 images). With such a small dataset, there is a high risk of overfitting to the background elements from the images. We will use masks to avoid this problem and focus only on the object of interest.
## 1 - Dataset Preparation
For this tutorial, we'll use a dataset consisting of 4 images of Bruce the Gnome:
| | |
| - | - |
|  |  |
|  |  |
This sample dataset is included in the invoke-training repo under [sample_data/bruce_the_gnome](https://github.com/invoke-ai/invoke-training/tree/main/sample_data/bruce_the_gnome).
## 2 - Generate Masks
Use the `generate_masks_for_jsonl_dataset.py` script to generate masks for your dataset based on a single prompt. In this case we are using the prompt `"a stuffed gnome"`:
```bash
python src/invoke_training/scripts/_experimental/masks/generate_masks_for_jsonl_dataset.py \
--in-jsonl sample_data/bruce_the_gnome/data.jsonl \
--out-jsonl sample_data/bruce_the_gnome/data_masks.jsonl \
--prompt "a stuffed gnome"
```
The mask generation script will produce the following outputs:
- A directory of generated masks: `sample_data/bruce_the_gnome/masks/`
- A new `.jsonl` file that references the mask images: `sample_data/bruce_the_gnome/data_masks.jsonl`
## 3 - Review the Generated Masks
Review the generated masks to make sure that the target regions were masked. You may need to adjust the prompt and re-generate the masks to achieve the desired result. Alternatively, you can edit the masks manually. The masks are simply single-channel grayscale images (0=background, 255=foreground).
Here are some examples of the masks that we just generated:
| | |
| - | - |
|  |  |
|  |  |
## 4 - Configuration
Below is the training configuration that we'll use for this tutorial.
Raw config file: [src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml).
```yaml title="sdxl_lora_masks_gnome_1x24gb.yaml"
--8<-- "src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml"
```
Full documentation of all of the configuration options is here: [LoRA SDXL Config](../../reference/config/pipelines/sdxl_lora.md)
There are a few things to note about this training config:
- We set `use_masks: True` in order to use the masks that we generated. This configuration is only compatible with datasets that have mask data.
- The `learning_rate`, `max_train_steps`, `save_every_n_steps`, and `validate_every_n_steps` are all _lower_ than typical for an SDXL LoRA training pipeline. The combination of masking with the small dataset size causes training to progress very quickly. These configuration fields were all adjusted accordingly to avoid overfitting.
## 5 - Start Training
Launch the training run.
```bash
# From inside the invoke-training/ source directory:
invoke-train -c src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml
```
Training takes ~30 mins on an NVIDIA RTX 4090.
## 6 - Monitor
In a new terminal, launch Tensorboard to monitor the training run:
```bash
tensorboard --logdir output/
```
Access Tensorboard at [localhost:6006](http://localhost:6006) in your browser.
Sample images will be logged to Tensorboard so that you can see how the model is evolving.
Once training is complete, select the model checkpoint that produces the best visual results. For this tutorial, we'll use the checkpoint from step 300:

*Screenshot of the Tensorboard UI showing the validation images for step 300. The validation prompt was: "A stuffed gnome at the beach with a pina colada in its hand."*
## 7 - Import into InvokeAI
If you haven't already, set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.
Import your trained LoRA model from the 'Models' tab.
Congratulations, you can now use your new Bruce-the-Gnome model! 🎉
================================================
FILE: docs/guides/stable_diffusion/robocats_finetune_sdxl.md
================================================
# Finetune - SDXL
This tutorial explains how to do a full finetune training run on a [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) base model.
## 0 - Prerequisites
Full model finetuning is more compute-intensive than parameter-efficient finetuning alternatives (e.g. LoRA or Textual Inversion). This tutorial requires a minimum of 24GB of GPU VRAM.
## 1 - Dataset Preparation
For this tutorial, we will use a dataset consisting of 14 images of robocats. The images were auto-captioned. Here are some sample images from the dataset, including their captions:
| | |
| - | - |
|  |  |
| *A white robot with blue eyes and a yellow nose sits on a rock, gazing at the camera, with a pink tree and a white cat in the background.* | *A white cat with green eyes and a blue collar sits on a moss-covered rock in a forest, gazing directly at the camera.* |
## 2 - Configuration
Below is the training configuration that we'll use for this tutorial.
Raw config file: [src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml).
```yaml title="sdxl_finetune_robocats_1x24gb.yaml"
--8<-- "src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml"
```
Full documentation of all of the configuration options is here: [Finetune SDXL Config](../../reference/config/pipelines/sdxl_finetune.md)
!!! note "`save_checkpoint_format`"
Note the `save_checkpoint_format` setting, as it is unique to full finetune training. For this tutorial, we have set `save_checkpoint_format: trained_only_diffusers`. This means that only the UNet model will be saved at each checkpoint, and it will be saved in diffusers format. This setting conserves disk space by not redundantly saving the non-trained weights. Before these UNet checkpoints can be used, they must either be merged into a full model, or extracted into a LoRA. Instructions for this follow later in this tutorial. A full explanation of the `save_checkpoint_format` options can be found here: [save_checkpoint_format][invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig.save_checkpoint_format].
## 3 - Start Training
Launch the training run.
```bash
# From inside the invoke-training/ source directory:
invoke-train -c src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml
```
Training takes ~45 mins on an NVIDIA RTX 4090.
## 4 - Monitor
In a new terminal, launch Tensorboard to monitor the training run:
```bash
tensorboard --logdir output/
```
Access Tensorboard at [localhost:6006](http://localhost:6006) in your browser.
Sample images will be logged to Tensorboard so that you can see how the model is evolving.
Once training is complete, select the model checkpoint that produces the best visual results.
## 5 - Prepare the trained model
Since we set `save_checkpoint_format: trained_only_diffusers`, our selected checkpoint only contains the UNet model weights. The checkpoint has the following directory structure:
```bash
output/robocats/sdxl_finetune/1715373799.3558652/checkpoints/checkpoint-epoch_00000500-step_00002000/
└── unet
├── config.json
└── diffusion_pytorch_model.safetensors
```
Before we can use this trained model, we must do one of the following:
- Prepare a full diffusers checkpoint with the new UNet weights.
- Extract the difference between the trained UNet and the original UNet into a LoRA model.
### Prepare a full model
If we want to use our finetuned UNet model, we must first package it into a format supported by applications like InvokeAI.
In this section we will assume that we have a full SDXL base model in diffusers format. It should have a directory structure like the one shown below. We simply need to replace the `unet/` directory with the one from our selected training checkpoint:
```bash
stable-diffusion-xl-base-1.0
├── model_index.json
├── scheduler
│ └── scheduler_config.json
├── text_encoder
│ ├── config.json
│ └── model.fp16.safetensors
├── text_encoder_2
│ ├── config.json
│ └── model.fp16.safetensors
├── tokenizer
│ ├── merges.txt
│ ├── special_tokens_map.json
│ ├── tokenizer_config.json
│ └── vocab.json
├── tokenizer_2
│ ├── merges.txt
│ ├── special_tokens_map.json
│ ├── tokenizer_config.json
│ └── vocab.json
├── unet # <-- Replace this directory with the trained checkpoint.
│ ├── config.json
│ └── diffusion_pytorch_model.fp16.safetensors
├── vae
│ ├── config.json
│ └── diffusion_pytorch_model.fp16.safetensors
└── vae_1_0
└── diffusion_pytorch_model.fp16.safetensors
```
!!! note "diffusers variants (e.g. 'fp16')"
In this example, notice that the `*.safetensors` files contain `.fp16.` in their filenames. Hugging Face refers to this identifier as a "variant". It is used to select between multiple model variants in their model hub.
In this case, we should add the `.fp16.` variant tag to our finetuned UNet for consistency with the rest of the model. Since we set `save_dtype: float16` in our training config, the `fp16` tag accurately represents the precision of our UNet model file.
### Extract a LoRA model
An alternative to using the finetuned UNet model directly is to compare it against the original and extract the difference as a LoRA model. The resultant LoRA has a much smaller file size and can be applied to any base model. But, the LoRA model is a *lossy* representation of the difference, so some quality degradation is expected.
To extract a LoRA model, run the following command:
```bash
python src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py \
--model-type SDXL \
--model-orig path/to/stable-diffusion-xl-base-1.0 \
--model-tuned output/robocats/sdxl_finetune/1715373799.3558652/checkpoints/checkpoint-epoch_00000500-step_00002000 \
--save-to robocats_lora_step_2000.safetensors \
--lora-rank 32
```
## 6 - Import into InvokeAI
If you haven't already, set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.
Import your finetuned diffusers model or your extracted LoRA from the 'Models' tab.
Congratulations, you can now use your new robocat model! 🎉
## 7 - Comparison: Finetune vs. LoRA Extraction
As noted earlier, the LoRA extraction process is lossy for a number of reasons.
Below, we compare images generated with the same seed and prompt for 3 different model configurations.
Prompt: *In robocat style, a robotic lion in the jungle.*
| SDXL Base 1.0 | w/ Finetuned UNet | w/ Extracted LoRA |
| - | - | - |
|  |  | 
================================================
FILE: docs/guides/stable_diffusion/textual_inversion_sdxl.md
================================================
# Textual Inversion - SDXL
This tutorial walks through a [Textual Inversion](https://arxiv.org/abs/2208.01618) training run with a [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) base model.
## 1 - Dataset
For this tutorial, we'll use a dataset consisting of 4 images of Bruce the Gnome:
| | |
| - | - |
|  |  |
|  |  |
This sample dataset is included in the invoke-training repo under [sample_data/bruce_the_gnome](https://github.com/invoke-ai/invoke-training/tree/main/sample_data/bruce_the_gnome).
Here are a few tips for preparing a Textual Inversion dataset:
- Aim for 4 to 50 images of your concept (object / style). The optimal number depends on many factors, and can be much higher than this for some use cases.
- Vary all of the image features that you *don't* want your TI embedding to contain (e.g. background, pose, lighting, etc.).
## 2 - Configuration
Below is the training configuration that we'll use for this tutorial.
Raw config file: [src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml).
Full config reference docs: [Textual Inversion SDXL Config](../../reference/config/pipelines/sdxl_textual_inversion.md)
```yaml title="sdxl_textual_inversion_gnome_1x24gb.yaml"
--8<-- "src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml"
```
## 3 - Start Training
[Install invoke-training](../../get-started/installation.md), if you haven't already.
Launch the Textual Inversion training pipeline:
```bash
# From inside the invoke-training/ source directory:
invoke-train -c src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml
```
Training takes ~40 mins on an NVIDIA RTX 4090.
## 4 - Monitor
In a new terminal, launch Tensorboard to monitor the training run:
```bash
tensorboard --logdir output/
```
Access Tensorboard at [localhost:6006](http://localhost:6006) in your browser.
Sample images will be logged to Tensorboard so that you can see how the Textual Inversion embedding is evolving.
Once training is complete, select the epoch that produces the best visual results.
For this tutorial, we'll choose epoch 500:

*Screenshot of the Tensorboard UI showing the validation images for epoch 500.*
## 5 - Transfer to InvokeAI
If you haven't already, set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation.
Copy the selected TI embedding into your `${INVOKEAI_ROOT}/autoimport/embedding/` directory. For example:
```bash
cp output/sdxl_ti_bruce_the_gnome/1702587511.2273068/checkpoint_epoch-00000500.safetensors ${INVOKEAI_ROOT}/autoimport/embedding/bruce_the_gnome.safetensors
```
Note that we renamed the file to `bruce_the_gnome.safetensors`. You can choose any file name, but this will become the token used to reference your embedding. So, in our case, we can refer to our new embedding by including `<bruce_the_gnome>` in our prompts.
Launch InvokeAI and you can now use your new `bruce_the_gnome` TI embedding! 🎉

*Example image generated with the prompt "`a photo of <bruce_the_gnome> at the park`".*
================================================
FILE: docs/index.md
================================================
# invoke-training
A library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion, etc.) that can be used in [InvokeAI](https://github.com/invoke-ai/InvokeAI).
## Documentation
The documentation is organized as follows:
- [Get Started](get-started/installation.md): Install `invoke-training` and run your first training pipeline.
- [Guides](guides/dataset_formats.md): Full tutorials for running popular training pipelines.
- [Config Reference](reference/config/index.md): Reference documentation for all supported training configuration options.
- [Contributing](contributing/development_environment.md): Information for `invoke-training` developers.
================================================
FILE: docs/reference/config/index.md
================================================
# Config Reference
This section contains reference documentation for the `invoke-training` configuration schema (i.e. documentation for all of the supported training options).
This documentation uses Python typing semantics to define the configuration schema. Typically the configuration for a training run is specified in a YAML file and then parsed against this schema.
================================================
FILE: docs/reference/config/pipelines/sd_lora.md
================================================
# `SdLoraConfig`
::: invoke_training.pipelines.stable_diffusion.lora.config.SdLoraConfig
options:
members:
- type
::: invoke_training.pipelines.stable_diffusion.lora.config.SdLoraConfig
options:
filters:
- "!^model_config"
- "!^type"
================================================
FILE: docs/reference/config/pipelines/sd_textual_inversion.md
================================================
# `SdTextualInversionConfig`
::: invoke_training.pipelines.stable_diffusion.textual_inversion.config.SdTextualInversionConfig
options:
members:
- type
::: invoke_training.pipelines.stable_diffusion.textual_inversion.config.SdTextualInversionConfig
options:
filters:
- "!^model_config"
- "!^type"
================================================
FILE: docs/reference/config/pipelines/sdxl_finetune.md
================================================
# `SdxlFinetuneConfig`
::: invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig
options:
members:
- type
::: invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig
options:
filters:
- "!^model_config"
- "!^type"
================================================
FILE: docs/reference/config/pipelines/sdxl_lora.md
================================================
# `SdxlLoraConfig`
::: invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig
options:
members:
- type
::: invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig
options:
filters:
- "!^model_config"
- "!^type"
================================================
FILE: docs/reference/config/pipelines/sdxl_lora_and_textual_inversion.md
================================================
# `SdxlLoraAndTextualInversionConfig`
::: invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config.SdxlLoraAndTextualInversionConfig
options:
members:
- type
::: invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config.SdxlLoraAndTextualInversionConfig
options:
filters:
- "!^model_config"
- "!^type"
================================================
FILE: docs/reference/config/pipelines/sdxl_textual_inversion.md
================================================
# `SdxlTextualInversionConfig`
Below is a sample yaml config file for Textual Inversion SDXL training ([raw file](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml)). All of the configuration fields are explained in detail on this page.
```yaml title="sdxl_textual_inversion_gnome_1x24gb.yaml"
--8<-- "src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml"
```
::: invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config.SdxlTextualInversionConfig
options:
members:
- type
::: invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config.SdxlTextualInversionConfig
options:
filters:
- "!^model_config"
- "!^type"
================================================
FILE: docs/reference/config/shared/data/data_loader_config.md
================================================
::: invoke_training.config.data.data_loader_config
options:
filters:
- "!^model_config"
================================================
FILE: docs/reference/config/shared/data/dataset_config.md
================================================
::: invoke_training.config.data.dataset_config
options:
filters:
- "!^model_config"
================================================
FILE: docs/reference/config/shared/optimizer_config.md
================================================
::: invoke_training.config.optimizer.optimizer_config
options:
filters:
- "!^model_config"
================================================
FILE: docs/templates/python/material/labels.html
================================================
================================================
FILE: mkdocs.yml
================================================
site_name: invoke-training
site_url: https://invoke-ai.github.io/invoke-training/
repo_name: invoke-ai/invoke-training
repo_url: https://github.com/invoke-ai/invoke-training
theme:
name: material
features:
- navigation.tabs
- navigation.indexes
- navigation.sections
- content.code.copy
markdown_extensions:
- admonition
- sane_lists
- pymdownx.highlight:
anchor_linenums: true
line_spans: __span
pygments_lang_class: true
- pymdownx.inlinehilite
- pymdownx.snippets
- pymdownx.superfences
nav:
- Welcome: index.md
- Get Started:
- get-started/installation.md
- get-started/quick-start.md
- Guides:
- Dataset Formats: guides/dataset_formats.md
- Model Merging: guides/model_merge.md
- Stable Diffusion Training:
- guides/stable_diffusion/robocats_finetune_sdxl.md
- guides/stable_diffusion/gnome_lora_masks_sdxl.md
- guides/stable_diffusion/textual_inversion_sdxl.md
- guides/stable_diffusion/dpo_lora_sd.md
- YAML Config Reference:
- reference/config/index.md
- pipelines:
- SD LoRA Config: reference/config/pipelines/sd_lora.md
- SD Textual Inversion Config: reference/config/pipelines/sd_textual_inversion.md
- SDXL LoRA Config: reference/config/pipelines/sdxl_lora.md
- SDXL Textual Inversion Config: reference/config/pipelines/sdxl_textual_inversion.md
- SDXL LoRA and Textual Inversion Config: reference/config/pipelines/sdxl_lora_and_textual_inversion.md
- SDXL Finetune Config: reference/config/pipelines/sdxl_finetune.md
- shared:
- data_loader_config: reference/config/shared/data/data_loader_config.md
- dataset_config: reference/config/shared/data/dataset_config.md
- optimizer_config: reference/config/shared/optimizer_config.md
- Contributing:
- contributing/development_environment.md
- contributing/directory_structure.md
- contributing/tests.md
- contributing/documentation.md
plugins:
- search
- mkdocstrings:
default_handler: python
custom_templates: docs/templates
handlers:
python:
options:
show_root_heading: false
show_root_toc_entry: false
show_bases: false
show_source: false
show_if_no_docstring: true
inherited_members: true
annotations_path: brief
separate_signature: true
show_signature_annotations: true
members_order: source
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["setuptools>=65.5", "pip>=22.3"]
build-backend = "setuptools.build_meta"
[project]
name = "invoke-training"
version = "0.0.1"
authors = [{ name = "The Invoke AI Team", email = "ryan@invoke.ai" }]
description = "A library for Stable Diffusion model training."
readme = "README.md"
requires-python = ">=3.10"
license = { text = "Apache-2.0" }
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
]
dependencies = [
"accelerate",
"datasets~=2.14.3",
"diffusers[torch]",
"einops",
"fastapi",
"gradio",
"invokeai>=5.10.0a1",
"numpy<2.0.0",
"omegaconf",
"peft~=0.11.1",
"pillow",
"prodigyopt",
"pydantic",
"pyyaml",
"safetensors",
"tensorboard",
"torch",
"torchvision",
"tqdm",
"transformers",
"uvicorn[standard]",
]
[project.optional-dependencies]
"xformers" = ["xformers>=0.0.28.post1; sys_platform!='darwin'"]
"bitsandbytes" = ["bitsandbytes>=0.43.1; sys_platform!='darwin'"]
"test" = [
"mkdocs",
"mkdocs-material",
"mkdocstrings[python]",
"pre-commit~=3.3.3",
"pytest~=7.4.0",
"ruff~=0.11.2",
"ruff-lsp",
]
[project.scripts]
"invoke-train" = "invoke_training.scripts.invoke_train:main"
"invoke-train-ui" = "invoke_training.scripts.invoke_train_ui:main"
"invoke-generate-images" = "invoke_training.scripts.invoke_generate_images:main"
"invoke-visualize-data-loading" = "invoke_training.scripts.invoke_visualize_data_loading:main"
[project.urls]
"Homepage" = "https://github.com/invoke-ai/invoke-training"
"Discord" = "https://discord.gg/ZmtBAhwWhy"
[tool.setuptools.package-data]
"invoke_training.assets" = ["*.png"]
"invoke_training.sample_configs" = ["**/*.yaml"]
"invoke_training.ui" = ["*.html"]
[tool.ruff]
src = ["src"]
lint.select = ["E", "F", "W", "C9", "N8", "I"]
target-version = "py310"
line-length = 120
[tool.pytest.ini_options]
addopts = "--strict-markers"
markers = [
"cuda: marks tests that require a CUDA GPU",
"loads_model: marks tests that require a model (or data) from the HF hub",
]
================================================
FILE: sample_data/bruce_the_gnome/data.jsonl
================================================
{"image": "001.png", "text": "A stuffed gnome sits on a wooden floor, facing right with a gray couch in the background."}
{"image": "002.png", "text": "A stuffed gnome stands on a black tiled floor, with a silver refrigerator and white wall in the background."}
{"image": "004.png", "text": "A stuffed gnome sits on a white marble floor, photorealistic."}
{"image": "003.png", "text": "A stuffed gnome sits on a gray tiled floor, facing the camera."}
================================================
FILE: src/invoke_training/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/accelerator/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/accelerator/accelerator_utils.py
================================================
import logging
import os
from typing import Literal
import datasets
import diffusers
import torch
import transformers
from accelerate import Accelerator
from accelerate.logging import MultiProcessAdapter, get_logger
from accelerate.utils import ProjectConfiguration
def initialize_accelerator(
    out_dir: str, gradient_accumulation_steps: int, mixed_precision: str, log_with: str
) -> Accelerator:
    """Build a Hugging Face `accelerate.Accelerator` for a training run.

    The accelerator's project directory is `out_dir` and its logging directory is `<out_dir>/logs`.

    Args:
        out_dir (str): The output directory where results will be written.
        gradient_accumulation_steps (int): Forwarded to `accelerate.Accelerator(...)`.
        mixed_precision (str): Forwarded to `accelerate.Accelerator(...)`.
        log_with (str): Forwarded to `accelerate.Accelerator(...)`.

    Returns:
        Accelerator: The configured accelerator.
    """
    logs_dir = os.path.join(out_dir, "logs")
    project_config = ProjectConfiguration(project_dir=out_dir, logging_dir=logs_dir)

    accelerator_kwargs = dict(
        project_config=project_config,
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision,
        log_with=log_with,
    )
    return Accelerator(**accelerator_kwargs)
def initialize_logging(logger_name: str, accelerator: Accelerator) -> MultiProcessAdapter:
    """Configure Python logging for a (possibly multi-process) run and return an accelerate logger.

    `logging.basicConfig(...)` is called at INFO level on every process. The Hugging Face libraries are then tuned per
    process: on the local main process, `datasets` and `transformers` log at WARNING level and `diffusers` at INFO
    level; on every other process all three libraries only log errors.

    Args:
        logger_name (str): The name passed to `accelerate.logging.get_logger(...)`.
        accelerator (Accelerator): The Accelerator whose `is_local_main_process` flag selects the verbosity.

    Returns:
        MultiProcessAdapter: The accelerate logger named `logger_name`.
    """
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    if not accelerator.is_local_main_process:
        # Secondary processes: surface only errors from the Hugging Face libraries.
        for hf_logging in (datasets.utils.logging, transformers.utils.logging, diffusers.utils.logging):
            hf_logging.set_verbosity_error()
    else:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()

    return get_logger(logger_name)
def get_mixed_precision_dtype(accelerator: Accelerator):
    """Map the accelerator's `mixed_precision` setting to the matching `torch.dtype`.

    `None` and `"no"` map to `torch.float32`, `"fp16"` to `torch.float16`, and `"bf16"` to `torch.bfloat16`.

    Args:
        accelerator (Accelerator): The Hugging Face Accelerator.

    Raises:
        NotImplementedError: If the accelerator's mixed_precision configuration is not one of the values above.

    Returns:
        torch.dtype: The weight dtype implied by the accelerator's mixed_precision configuration.
    """
    mode = accelerator.mixed_precision
    if mode is None or mode == "no":
        return torch.float32
    if mode == "fp16":
        return torch.float16
    if mode == "bf16":
        return torch.bfloat16
    raise NotImplementedError(f"mixed_precision mode '{accelerator.mixed_precision}' is not yet supported.")
def get_dtype_from_str(dtype_str: Literal["float16", "bfloat16", "float32"]) -> torch.dtype:
    """Convert a dtype name ("float16", "bfloat16" or "float32") to the corresponding `torch.dtype`.

    Raises:
        ValueError: If `dtype_str` is not one of the supported names.
    """
    known_dtypes = (
        ("float16", torch.float16),
        ("bfloat16", torch.bfloat16),
        ("float32", torch.float32),
    )
    for name, dtype in known_dtypes:
        if dtype_str == name:
            return dtype
    raise ValueError(f"Unsupported dtype: {dtype_str}")
================================================
FILE: src/invoke_training/_shared/checkpoints/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/checkpoints/checkpoint_tracker.py
================================================
import os
import shutil
import typing
class CheckpointTracker:
    """A utility class for managing checkpoint paths.

    Manages checkpoint paths of the following forms:
    - Checkpoint directories: `{base_dir}/{prefix}-epoch_{num_epochs}-step_{num_steps}`
    - Checkpoint files: `{base_dir}/{prefix}-epoch_{num_epochs}-step_{num_steps}{extension}`

    `prune(...)` only considers entries of `base_dir` whose names exactly match this form (with this tracker's prefix
    and extension). Any other files or directories in `base_dir` are left untouched.
    """

    def __init__(
        self,
        base_dir: str,
        prefix: str,
        extension: typing.Optional[str] = None,
        max_checkpoints: typing.Optional[int] = None,
        index_padding: int = 8,
    ):
        """Initialize a CheckpointTracker.

        Args:
            base_dir (str): The base checkpoint directory.
            prefix (str): A prefix applied to every checkpoint. Leading/trailing whitespace is stripped when building
                and matching checkpoint names.
            extension (str, optional): If set, this is the file extension that will be applied to all checkpoints
                (usually one of ".pt", ".ckpt", or ".safetensors"). If None, then it will be assumed that we are
                managing checkpoint directories rather than files.
            max_checkpoints (typing.Optional[int], optional): The maximum number of checkpoints that should exist in
                base_dir. If None, `prune(...)` never deletes anything.
            index_padding (int, optional): The length of the zero-padded epoch/step counts in the generated checkpoint
                names. E.g. index_padding=8 would produce checkpoint paths like
                "base_dir/prefix-epoch_00000001-step_00000001.ckpt".

        Raises:
            ValueError: If extension is provided, but it does not start with a '.'.
        """
        if extension is not None and not extension.startswith("."):
            raise ValueError(f"extension='{extension}' must start with a '.'.")

        self._base_dir = base_dir
        self._prefix = prefix
        self._extension = extension
        self._max_checkpoints = max_checkpoints
        self._index_padding = index_padding

    def _parse_step(self, name: str) -> typing.Optional[int]:
        """Return the step count encoded in `name`, or None if `name` is not a checkpoint managed by this tracker."""
        head = f"{self._prefix.strip()}-epoch_"
        tail = self._extension or ""
        if len(name) < len(head) + len(tail) or not name.startswith(head) or not name.endswith(tail):
            return None
        core = name[len(head) : len(name) - len(tail)]
        epoch_str, sep, step_str = core.partition("-step_")
        # isdecimal() (rather than isdigit()) guarantees that int(...) below cannot fail.
        if not sep or not epoch_str.isdecimal() or not step_str.isdecimal():
            return None
        return int(step_str)

    def prune(self, buffer_num: int = 1) -> int:
        """Delete checkpoint files and directories so that there are at most `max_checkpoints - buffer_num` checkpoints
        remaining. The checkpoints with the lowest step counts will be deleted.

        Only entries whose names match the checkpoint naming scheme of this tracker are counted or deleted.

        Args:
            buffer_num (int, optional): The number below `max_checkpoints` to 'free-up'.

        Returns:
            int: The number of checkpoints deleted.
        """
        if self._max_checkpoints is None:
            return 0
        if not os.path.isdir(self._base_dir):
            # No checkpoint has been written yet, so there is nothing to prune.
            return 0

        # (step, name) pairs for every entry this tracker could have created.
        checkpoints = []
        for name in os.listdir(self._base_dir):
            step = self._parse_step(name)
            if step is not None:
                checkpoints.append((step, name))
        # Oldest (lowest step) first; ties are broken by name so the order is deterministic.
        checkpoints.sort()

        num_to_keep = self._max_checkpoints - buffer_num
        num_to_remove = max(0, len(checkpoints) - num_to_keep)
        to_remove = checkpoints[:num_to_remove]
        for _, name in to_remove:
            path = os.path.join(self._base_dir, name)
            if os.path.isdir(path) and not os.path.islink(path):
                # Delete checkpoint directory.
                shutil.rmtree(path)
            else:
                # Delete checkpoint file (or a symlink to a checkpoint).
                os.remove(path)

        # When num_to_keep is negative, the slice removes every checkpoint; report what was actually deleted.
        return len(to_remove)

    def get_path(self, epoch: int, step: int) -> str:
        """Get the checkpoint path for the given epoch and step counts.

        Args:
            epoch (int): The number of completed epochs.
            step (int): The number of completed training steps.

        Returns:
            str: The checkpoint path.
        """
        suffix = self._extension or ""
        return os.path.join(
            self._base_dir,
            f"{self._prefix.strip()}-epoch_{epoch:0>{self._index_padding}}-step_{step:0>{self._index_padding}}{suffix}",
        )
================================================
FILE: src/invoke_training/_shared/checkpoints/lora_checkpoint_utils.py
================================================
from pathlib import Path
import peft
import torch
def save_multi_model_peft_checkpoint(checkpoint_dir: Path | str, models: dict[str, peft.PeftModel]):
    """Save several PeftModels into a single checkpoint directory.

    Each model is written with `save_pretrained(...)` to the subdirectory `checkpoint_dir / <key>`, where `<key>` is
    the model's key in `models`. Use `load_multi_model_peft_checkpoint(...)` to load the result.
    """
    root = Path(checkpoint_dir)
    for key in models:
        model = models[key]
        assert isinstance(model, peft.PeftModel)

        # HACK(ryand): PeftModel.save_pretrained(...) expects the config to have a "_name_or_path" entry. For now, we
        # set this to None here. This should be fixed upstream in PEFT.
        needs_name_or_path = (
            hasattr(model, "config") and isinstance(model.config, dict) and "_name_or_path" not in model.config
        )
        if needs_name_or_path:
            model.config["_name_or_path"] = None

        model.save_pretrained(str(root / key))
def load_multi_model_peft_checkpoint(
    checkpoint_dir: Path | str,
    models: dict[str, torch.nn.Module],
    is_trainable: bool = False,
    raise_if_subdir_missing: bool = True,
) -> dict[str, torch.nn.Module]:
    """Load a multi-model PEFT checkpoint that was saved with `save_multi_model_peft_checkpoint(...)`.

    For each `(key, model)` in `models`, the adapter stored in `checkpoint_dir / key` is loaded onto `model` with
    `peft.PeftModel.from_pretrained(...)`.

    Args:
        checkpoint_dir (Path | str): The checkpoint root directory. It must exist.
        models (dict[str, torch.nn.Module]): The base models to wrap, keyed by checkpoint subdirectory name.
        is_trainable (bool, optional): Forwarded to `peft.PeftModel.from_pretrained(...)`.
        raise_if_subdir_missing (bool, optional): If True, a missing subdirectory raises a ValueError. If False, the
            corresponding model is returned unchanged.

    Returns:
        dict[str, torch.nn.Module]: The loaded (or passed-through) models, with the same keys as `models`.
    """
    root = Path(checkpoint_dir)
    assert root.exists()

    loaded = {}
    for key, model in models.items():
        model_dir: Path = root / key
        if not model_dir.exists():
            if raise_if_subdir_missing:
                raise ValueError(f"'{model_dir}' does not exist.")
            # Pass through the model unchanged.
            loaded[key] = model
            continue
        loaded[key] = peft.PeftModel.from_pretrained(model, model_dir, is_trainable=is_trainable)
    return loaded
# This implementation is based on
# https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/lora_dreambooth/convert_peft_sd_lora_to_kohya_ss.py#L20
def _convert_peft_state_dict_to_kohya_state_dict(
    lora_config: peft.LoraConfig,
    peft_state_dict: dict[str, torch.Tensor],
    prefix: str,
    dtype: torch.dtype,
) -> dict[str, torch.Tensor]:
    """Rename the keys of a PEFT LoRA state_dict to the kohya_ss convention and cast the weights to `dtype`.

    "base_model.model" becomes `prefix`, "lora_A"/"lora_B" become "lora_down"/"lora_up", and all but the last two
    dots become underscores. For every "lora_down" key, a `<module>.alpha` entry holding `lora_config.lora_alpha` is
    added.
    """
    converted = {}
    for peft_key, weight in peft_state_dict.items():
        key = peft_key.replace("base_model.model", prefix).replace("lora_A", "lora_down").replace("lora_B", "lora_up")
        # Keep only the final two dots (assumes keys end like "<module>.lora_down.weight" — TODO confirm).
        key = key.replace(".", "_", key.count(".") - 2)
        converted[key] = weight.to(dtype)

        if "lora_down" in key:
            # Set alpha parameter, keyed by the (underscore-joined) module name.
            module_name = key.split(".")[0]
            converted[f"{module_name}.alpha"] = torch.tensor(lora_config.lora_alpha).to(dtype)
    return converted
def _convert_peft_models_to_kohya_state_dict(
    kohya_prefixes: list[str], models: list[peft.PeftModel]
) -> dict[str, torch.Tensor]:
    """Merge the "default" LoRA adapters of several PeftModels into one kohya_ss-format float32 state_dict.

    `kohya_prefixes[i]` is the kohya key prefix used for `models[i]`; the two lists must have the same length.
    """
    adapter_name = "default"
    merged = {}
    for kohya_prefix, peft_model in zip(kohya_prefixes, models, strict=True):
        lora_config = peft_model.peft_config[adapter_name]
        assert isinstance(lora_config, peft.LoraConfig)

        model_state_dict = peft.get_peft_model_state_dict(peft_model, adapter_name=adapter_name)
        merged |= _convert_peft_state_dict_to_kohya_state_dict(
            lora_config=lora_config,
            peft_state_dict=model_state_dict,
            prefix=kohya_prefix,
            dtype=torch.float32,
        )
    return merged
================================================
FILE: src/invoke_training/_shared/checkpoints/serialization.py
================================================
import typing
from pathlib import Path
import safetensors.torch
import torch
def save_state_dict(state_dict: typing.Dict[str, torch.Tensor], out_file: typing.Union[Path, str]):
    """Save a state_dict to a file, choosing the format from the file extension.

    Supported extensions:
    - ".ckpt" -> torch
    - ".pt" -> torch
    - ".safetensors" -> safetensors

    Args:
        state_dict (typing.Dict[str, torch.Tensor]): The state_dict to save.
        out_file (Path | str): The output file to save to.

    Raises:
        ValueError: If the `out_file` has an unsupported file extension.
    """
    path = Path(out_file)
    suffix = path.suffix

    if suffix in (".ckpt", ".pt"):
        torch.save(state_dict, path)
        return
    if suffix == ".safetensors":
        safetensors.torch.save_file(state_dict, path)
        return
    raise ValueError(f"Unsupported file extension: '{path.suffix}'.")
def load_state_dict(in_file: typing.Union[Path, str]) -> typing.Dict[str, torch.Tensor]:
    """Load a state_dict from a file, choosing the format from the file extension.

    Supported extensions:
    - ".ckpt" -> torch
    - ".pt" -> torch
    - ".safetensors" -> safetensors

    Torch files are loaded with `weights_only=True`, so unpickling is restricted to tensors and primitive containers.
    This means a malicious checkpoint cannot execute arbitrary code during loading.

    Args:
        in_file (Path | str): The input file to load from.

    Raises:
        ValueError: If the `in_file` has an unsupported file extension.

    Returns:
        typing.Dict[str, torch.Tensor]: The loaded state_dict.
    """
    in_file = Path(in_file)
    if in_file.suffix in (".ckpt", ".pt"):
        # A full pickle load would allow arbitrary code execution from untrusted checkpoint files.
        return torch.load(in_file, weights_only=True)
    elif in_file.suffix == ".safetensors":
        return safetensors.torch.load_file(in_file)
    else:
        raise ValueError(f"Unsupported file extension: '{in_file.suffix}'.")
================================================
FILE: src/invoke_training/_shared/data/ARCHITECTURE.md
================================================
# Dataset Architecture
Dataset handling is split into 3 layers of abstraction: Datasets, Transforms, and DataLoaders. Each is explained in more detail below.
## Datasets
Datasets implement the [torch.utils.data.Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files) interface.
Most dataset classes act as an abstraction over a specific dataset format.
## Transforms
Transforms are functions applied to data loaded by Datasets. For example, the `SDImageTransform` implements image augmentations for Stable Diffusion training.
Transforms are kept separate from the underlying datasets for several reasons:
- It is easier to write tests for isolated transforms.
- Modular transforms can often be re-used for multiple base datasets.
- Modular transforms make it easy to customize datasets for different situations. For example, you may want to wrap a dataset with one set of transforms initially to populate a cache, and then with a different set of transforms to read from the cache.
Transforms are applied to a dataset via the `TransformDataset` class.
## DataLoaders
The dataset classes (with composed transforms) are wrapped in a `torch.utils.data.DataLoader` that handles batch collation, multi-processing, etc.
================================================
FILE: src/invoke_training/_shared/data/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/data/data_loaders/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/data/data_loaders/dreambooth_sd_dataloader.py
================================================
import typing
from torch.utils.data import ConcatDataset, DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
build_aspect_ratio_bucket_manager,
sd_image_caption_collate_fn,
)
from invoke_training._shared.data.datasets.image_dir_dataset import ImageDirDataset
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler
from invoke_training._shared.data.samplers.batch_offset_sampler import BatchOffsetSampler
from invoke_training._shared.data.samplers.concat_sampler import ConcatSampler
from invoke_training._shared.data.samplers.interleaved_sampler import InterleavedSampler
from invoke_training._shared.data.samplers.offset_sampler import OffsetSampler
from invoke_training._shared.data.transforms.constant_field_transform import ConstantFieldTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig
def build_dreambooth_sd_dataloader(
    config: DreamboothSDDataLoaderConfig,
    batch_size: int,
    text_encoder_output_cache_dir: typing.Optional[str] = None,
    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,
    vae_output_cache_dir: typing.Optional[str] = None,
    shuffle: bool = True,
    sequential_batching: bool = False,
) -> DataLoader:
    """Construct a DataLoader for a DreamBooth dataset (instance images plus optional class images).

    Instance examples get the constant caption `config.instance_caption` and a loss weight of 1.0. If
    `config.class_dataset` is set, class examples get the constant caption `config.class_caption` and the loss weight
    `config.class_data_loss_weight`. Both datasets are concatenated (instance first, then class) and then sampled
    either interleaved or sequentially (see `sequential_batching`).

    Args:
        config (DreamboothSDDataLoaderConfig): The data loader config.
        batch_size (int): The batch size. In fixed-resolution mode (`config.aspect_ratio_buckets is None`) it is
            passed to the DataLoader. In aspect-ratio-bucket mode it is passed to the bucket batch samplers.
        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should be
            loaded from.
        text_encoder_cache_field_to_output_field (dict[str, str], optional): Maps cached text encoder fields to
            output example fields. Must be provided when `text_encoder_output_cache_dir` is set.
        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If
            set, the SDImageTransform (resolution/bucketing, center_crop, random_flip) is skipped. The cached
            `vae_output`, `original_size_hw` and `crop_top_left_yx` fields are loaded instead, and the `image` field is
            dropped.
        shuffle (bool, optional): Whether to shuffle the dataset order.
        sequential_batching (bool, optional): If True, the internal dataset will be processed sequentially rather than
            interleaving class and instance examples. This is intended to be used when processing the entire dataset for
            caching purposes. Defaults to False.

    Returns:
        DataLoader
    """
    # Prepare instance dataset.
    # NOTE(review): id_prefix presumably prefixes each example's "id" so instance/class cache keys cannot collide.
    base_instance_dataset = ImageDirDataset(
        config.instance_dataset.dataset_dir,
        id_prefix="instance_",
        keep_in_memory=config.instance_dataset.keep_in_memory,
    )
    instance_dataset = TransformDataset(
        base_instance_dataset,
        [
            ConstantFieldTransform("caption", config.instance_caption),
            ConstantFieldTransform("loss_weight", 1.0),
        ],
    )
    datasets = [instance_dataset]

    # Prepare class dataset.
    base_class_dataset = None
    class_dataset = None
    if config.class_dataset is not None:
        base_class_dataset = ImageDirDataset(
            config.class_dataset.dataset_dir, id_prefix="class_", keep_in_memory=config.class_dataset.keep_in_memory
        )
        class_dataset = TransformDataset(
            base_class_dataset,
            [
                ConstantFieldTransform("caption", config.class_caption),
                ConstantFieldTransform("loss_weight", config.class_data_loss_weight),
            ],
        )
        datasets.append(class_dataset)

    # Merge instance dataset and class dataset.
    # Instance examples occupy indices [0, len(instance)); class examples follow them. This is why the class samplers
    # below are offset by len(base_instance_dataset).
    merged_dataset = ConcatDataset(datasets)

    # Initialize either the fixed target resolution or aspect ratio buckets.
    target_resolution = None
    aspect_ratio_bucket_manager = None
    instance_sampler = None
    class_sampler = None
    if config.aspect_ratio_buckets is None:
        # Fixed-resolution mode: per-example samplers; the DataLoader does the batching.
        target_resolution = config.resolution
        # TODO(ryand): Provide a seeded generator.
        instance_sampler = RandomSampler(instance_dataset) if shuffle else SequentialSampler(instance_dataset)
        if base_class_dataset is not None:
            class_sampler = RandomSampler(class_dataset) if shuffle else SequentialSampler(class_dataset)
            class_sampler = OffsetSampler(class_sampler, offset=len(base_instance_dataset))
    else:
        # Bucket mode: the samplers yield whole batches of indices grouped by aspect ratio bucket.
        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)
        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.
        instance_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
            bucket_manager=aspect_ratio_bucket_manager,
            image_sizes=base_instance_dataset.get_image_dimensions(),
            batch_size=batch_size,
            shuffle=shuffle,
            seed=0,
        )
        if base_class_dataset is not None:
            class_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
                bucket_manager=aspect_ratio_bucket_manager,
                image_sizes=base_class_dataset.get_image_dimensions(),
                batch_size=batch_size,
                shuffle=shuffle,
                seed=0,
            )
            class_sampler = BatchOffsetSampler(class_sampler, offset=len(base_instance_dataset))

    # Add transforms to the merged dataset.
    # Order matters: image handling (augment, or load from VAE cache) runs first, then the text encoder cache load.
    all_transforms = []
    if vae_output_cache_dir is None:
        all_transforms.append(
            SDImageTransform(
                image_field_names=["image"],
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )
    else:
        vae_cache = TensorDiskCache(vae_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache,
                cache_key_field="id",
                cache_field_to_output_field={
                    "vae_output": "vae_output",
                    "original_size_hw": "original_size_hw",
                    "crop_top_left_yx": "crop_top_left_yx",
                },
            )
        )
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))

    if text_encoder_output_cache_dir is not None:
        assert text_encoder_cache_field_to_output_field is not None
        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=text_encoder_cache,
                cache_key_field="id",
                cache_field_to_output_field=text_encoder_cache_field_to_output_field,
            )
        )

    merged_dataset = TransformDataset(merged_dataset, all_transforms)

    # Choose between sequential vs. interleaved merging of the instance and class samplers.
    # Sequential sampling is typically used to populate a cache, because it guarantees that all examples will be
    # included in an epoch.
    samplers = [instance_sampler]
    if class_sampler is not None:
        samplers.append(class_sampler)

    if sequential_batching:
        sampler = ConcatSampler(samplers)
    else:
        sampler = InterleavedSampler(samplers)

    if config.aspect_ratio_buckets is None:
        # Per-example sampler: the DataLoader groups indices into batches of `batch_size`.
        return DataLoader(
            merged_dataset,
            sampler=sampler,
            collate_fn=sd_image_caption_collate_fn,
            batch_size=batch_size,
            num_workers=config.dataloader_num_workers,
        )
    else:
        # If config.aspect_ratio_buckets is not None, then we are using a batch sampler.
        return DataLoader(
            merged_dataset,
            batch_sampler=sampler,
            collate_fn=sd_image_caption_collate_fn,
            num_workers=config.dataloader_num_workers,
        )
================================================
FILE: src/invoke_training/_shared/data/data_loaders/image_caption_flux_dataloader.py
================================================
import typing
from torch.utils.data import DataLoader
from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
build_aspect_ratio_bucket_manager,
)
from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
sd_image_caption_collate_fn as flux_image_caption_collate_fn,
)
from invoke_training._shared.data.datasets.build_dataset import (
build_hf_hub_image_caption_dataset,
build_image_caption_dir_dataset,
build_image_caption_jsonl_dataset,
)
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import (
AspectRatioBucketBatchSampler,
)
from invoke_training._shared.data.transforms.caption_prefix_transform import CaptionPrefixTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.flux_image_transform import FluxImageTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training.config.data.data_loader_config import ImageCaptionFluxDataLoaderConfig
from invoke_training.config.data.dataset_config import (
HFHubImageCaptionDatasetConfig,
ImageCaptionDirDatasetConfig,
ImageCaptionJsonlDatasetConfig,
)
def build_image_caption_flux_dataloader(  # noqa: C901
    config: ImageCaptionFluxDataLoaderConfig,
    batch_size: int,
    use_masks: bool = False,
    text_encoder_output_cache_dir: typing.Optional[str] = None,
    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,
    vae_output_cache_dir: typing.Optional[str] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Construct a DataLoader for an image-caption dataset for Flux.1-dev.

    Args:
        config (ImageCaptionFluxDataLoaderConfig): The dataset config.
        batch_size (int): The DataLoader batch size. In aspect-ratio-bucket mode it is passed to the bucket batch
            sampler instead.
        use_masks (bool, optional): If True, the "mask" field is kept. It is transformed together with the image, or
            loaded from the VAE cache when `vae_output_cache_dir` is set. If False, the "mask" field is dropped.
        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should be
            loaded from.
        text_encoder_cache_field_to_output_field (dict[str, str], optional): Maps cached text encoder fields to
            output example fields. Must be provided when `text_encoder_output_cache_dir` is set.
        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If
            set, the FluxImageTransform augmentations are skipped, the "image" (and "mask") fields are dropped, and the
            cached fields are loaded instead.
        shuffle (bool, optional): Whether to shuffle the dataset order.

    Raises:
        ValueError: If `config.dataset` is not one of the supported dataset config types.

    Returns:
        DataLoader
    """
    if isinstance(config.dataset, HFHubImageCaptionDatasetConfig):
        base_dataset = build_hf_hub_image_caption_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionJsonlDatasetConfig):
        base_dataset = build_image_caption_jsonl_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionDirDatasetConfig):
        base_dataset = build_image_caption_dir_dataset(config.dataset)
    else:
        raise ValueError(f"Unexpected dataset config type: '{type(config.dataset)}'.")

    # Initialize either the fixed target resolution or aspect ratio buckets.
    if config.aspect_ratio_buckets is None:
        aspect_ratio_bucket_manager = None
        batch_sampler = None
    else:
        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)
        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.
        batch_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
            bucket_manager=aspect_ratio_bucket_manager,
            image_sizes=base_dataset.get_image_dimensions(),
            batch_size=batch_size,
            shuffle=shuffle,
            seed=0,
        )

    # The transforms are applied in list order, so the order of the appends below is significant.
    all_transforms = []
    if config.caption_prefix is not None:
        # A single space separates the prefix from the original caption.
        all_transforms.append(CaptionPrefixTransform(caption_field_name="caption", prefix=config.caption_prefix + " "))

    if vae_output_cache_dir is None:
        image_field_names = ["image"]
        if use_masks:
            # The mask goes through the same spatial transform as the image (only "image" is normalized to [-1, 1]).
            image_field_names.append("mask")
        else:
            all_transforms.append(DropFieldTransform("mask"))

        # NOTE(review): config.resolution is passed even in bucket mode (unlike the DreamBooth loader, which passes
        # None there); presumably FluxImageTransform prefers the bucket manager when it is set — confirm.
        all_transforms.append(
            FluxImageTransform(
                image_field_names=image_field_names,
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=config.resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )
    else:
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))
        # The raw mask is dropped too; if use_masks is set, the cached mask is re-loaded by the transform below.
        all_transforms.append(DropFieldTransform("mask"))

        vae_cache = TensorDiskCache(vae_output_cache_dir)

        cache_field_to_output_field = {
            "vae_output": "vae_output",
            "original_size_hw": "original_size_hw",
            "crop_top_left_yx": "crop_top_left_yx",
        }
        if use_masks:
            cache_field_to_output_field["mask"] = "mask"
        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache,
                cache_key_field="id",
                cache_field_to_output_field=cache_field_to_output_field,
            )
        )

    if text_encoder_output_cache_dir is not None:
        assert text_encoder_cache_field_to_output_field is not None
        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=text_encoder_cache,
                cache_key_field="id",
                cache_field_to_output_field=text_encoder_cache_field_to_output_field,
            )
        )

    dataset = TransformDataset(base_dataset, all_transforms)

    if batch_sampler is None:
        # Fixed-resolution mode: the DataLoader shuffles and batches the examples itself.
        return DataLoader(
            dataset,
            shuffle=shuffle,
            collate_fn=flux_image_caption_collate_fn,
            batch_size=batch_size,
            num_workers=config.dataloader_num_workers,
        )
    else:
        # Bucket mode: the batch sampler already yields whole batches (and handles shuffling).
        return DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            collate_fn=flux_image_caption_collate_fn,
            num_workers=config.dataloader_num_workers,
        )
================================================
FILE: src/invoke_training/_shared/data/data_loaders/image_caption_sd_dataloader.py
================================================
import typing
import torch
from torch.utils.data import DataLoader
from invoke_training._shared.data.datasets.build_dataset import (
build_hf_hub_image_caption_dataset,
build_image_caption_dir_dataset,
build_image_caption_jsonl_dataset,
)
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler
from invoke_training._shared.data.transforms.caption_prefix_transform import CaptionPrefixTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager
from invoke_training.config.data.data_loader_config import AspectRatioBucketConfig, ImageCaptionSDDataLoaderConfig
from invoke_training.config.data.dataset_config import (
HFHubImageCaptionDatasetConfig,
ImageCaptionDirDatasetConfig,
ImageCaptionJsonlDatasetConfig,
)
def sd_image_caption_collate_fn(examples):
    """Collate a list of image-caption examples into a single batch dict.

    The set of optional fields in the batch is decided by the first example. "id" is always included. Image-like
    and embedding fields are stacked into tensors, per-example metadata and captions are gathered into lists, and
    "loss_weight" is converted to a 1-D tensor.

    Args:
        examples: The per-example dicts produced by the dataset transforms.

    Returns:
        dict: The collated batch.
    """

    def gather(field):
        return [example[field] for example in examples]

    batch = {"id": gather("id")}
    first = examples[0]

    if "image" in first:
        batch["image"] = torch.stack(gather("image"))
    # Per-example metadata and strings are kept as plain lists rather than tensors.
    for field in ("original_size_hw", "crop_top_left_yx", "caption"):
        if field in first:
            batch[field] = gather(field)
    if "loss_weight" in first:
        batch["loss_weight"] = torch.tensor(gather("loss_weight"))
    if "prompt_embeds" in first:
        # "pooled_prompt_embeds" is expected whenever "prompt_embeds" is present.
        batch["prompt_embeds"] = torch.stack(gather("prompt_embeds"))
        batch["pooled_prompt_embeds"] = torch.stack(gather("pooled_prompt_embeds"))
    for field in ("text_encoder_output", "vae_output", "mask"):
        if field in first:
            batch[field] = torch.stack(gather(field))
    return batch
def build_aspect_ratio_bucket_manager(config: AspectRatioBucketConfig):
    """Create an AspectRatioBucketManager from the bucket constraints in `config`.

    Args:
        config (AspectRatioBucketConfig): Provides target_resolution, start_dim, end_dim and divisible_by.

    Returns:
        AspectRatioBucketManager: The bucket manager built via `AspectRatioBucketManager.from_constraints(...)`.
    """
    constraints = dict(
        target_resolution=config.target_resolution,
        start_dim=config.start_dim,
        end_dim=config.end_dim,
        divisible_by=config.divisible_by,
    )
    return AspectRatioBucketManager.from_constraints(**constraints)
def build_image_caption_sd_dataloader(  # noqa: C901
    config: ImageCaptionSDDataLoaderConfig,
    batch_size: int,
    use_masks: bool = False,
    text_encoder_output_cache_dir: typing.Optional[str] = None,
    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,
    vae_output_cache_dir: typing.Optional[str] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Construct a DataLoader for an image-caption dataset for Stable Diffusion.

    Args:
        config (ImageCaptionSDDataLoaderConfig): The data loader config.
        batch_size (int): The DataLoader batch size.
        use_masks (bool, optional): If True, keep a "mask" field in each example. It is either transformed alongside
            the image, or loaded from the VAE cache when `vae_output_cache_dir` is set. If False, any "mask" field
            is dropped.
        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached. If set,
            the cached outputs are loaded into each example (keyed by the example "id").
        text_encoder_cache_field_to_output_field (dict[str, str], optional): Maps cached text encoder field names
            to example field names. Required when `text_encoder_output_cache_dir` is set.
        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached. If set, the image
            augmentation transforms are skipped, the "image" field is dropped, and the cached "vae_output",
            "original_size_hw" and "crop_top_left_yx" fields are loaded instead.
        shuffle (bool, optional): Whether to shuffle the dataset order.

    Returns:
        DataLoader

    Raises:
        ValueError: If the dataset config type is unsupported, or if `text_encoder_output_cache_dir` is set without
            `text_encoder_cache_field_to_output_field`.
    """
    if isinstance(config.dataset, HFHubImageCaptionDatasetConfig):
        base_dataset = build_hf_hub_image_caption_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionJsonlDatasetConfig):
        base_dataset = build_image_caption_jsonl_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionDirDatasetConfig):
        base_dataset = build_image_caption_dir_dataset(config.dataset)
    else:
        raise ValueError(f"Unexpected dataset config type: '{type(config.dataset)}'.")

    # Initialize either the fixed target resolution or aspect ratio buckets.
    if config.aspect_ratio_buckets is None:
        target_resolution = config.resolution
        aspect_ratio_bucket_manager = None
        batch_sampler = None
    else:
        target_resolution = None
        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)
        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.
        batch_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
            bucket_manager=aspect_ratio_bucket_manager,
            image_sizes=base_dataset.get_image_dimensions(),
            batch_size=batch_size,
            shuffle=shuffle,
            seed=0,
        )

    all_transforms = []
    if config.caption_prefix is not None:
        all_transforms.append(CaptionPrefixTransform(caption_field_name="caption", prefix=config.caption_prefix + " "))

    if vae_output_cache_dir is None:
        image_field_names = ["image"]
        if use_masks:
            image_field_names.append("mask")
        else:
            all_transforms.append(DropFieldTransform("mask"))

        all_transforms.append(
            SDImageTransform(
                image_field_names=image_field_names,
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )
    else:
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))
        # The mask (if requested) is re-loaded from the VAE cache below.
        all_transforms.append(DropFieldTransform("mask"))

        vae_cache = TensorDiskCache(vae_output_cache_dir)
        cache_field_to_output_field = {
            "vae_output": "vae_output",
            "original_size_hw": "original_size_hw",
            "crop_top_left_yx": "crop_top_left_yx",
        }
        if use_masks:
            cache_field_to_output_field["mask"] = "mask"
        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache,
                cache_key_field="id",
                cache_field_to_output_field=cache_field_to_output_field,
            )
        )

    if text_encoder_output_cache_dir is not None:
        # Explicit check instead of `assert` so the validation is not stripped when running with `python -O`.
        if text_encoder_cache_field_to_output_field is None:
            raise ValueError(
                "'text_encoder_cache_field_to_output_field' must be set when 'text_encoder_output_cache_dir' is set."
            )
        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=text_encoder_cache,
                cache_key_field="id",
                cache_field_to_output_field=text_encoder_cache_field_to_output_field,
            )
        )

    dataset = TransformDataset(base_dataset, all_transforms)

    if batch_sampler is None:
        return DataLoader(
            dataset,
            shuffle=shuffle,
            collate_fn=sd_image_caption_collate_fn,
            batch_size=batch_size,
            num_workers=config.dataloader_num_workers,
        )

    # The bucket batch sampler already handles batching and shuffling.
    return DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        collate_fn=sd_image_caption_collate_fn,
        num_workers=config.dataloader_num_workers,
    )
================================================
FILE: src/invoke_training/_shared/data/data_loaders/image_pair_preference_sd_dataloader.py
================================================
import typing
import torch
from torch.utils.data import DataLoader
from invoke_training._shared.data.datasets.build_dataset import build_hf_image_pair_preference_dataset
from invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training.pipelines._experimental.sd_dpo_lora.config import ImagePairPreferenceSDDataLoaderConfig
def sd_image_pair_preference_collate_fn(examples):
    """Collate a list of image-pair preference examples into a single batch dict.

    Tensor fields are stacked along a new batch dimension. All other supported fields are gathered into plain
    lists. Fields absent from the first example are omitted from the batch.

    Args:
        examples: The per-example dicts produced by the dataset transforms.

    Returns:
        dict: The collated batch.

    Raises:
        ValueError: If the first example contains a field this function does not know how to collate.
    """
    tensor_fields = {
        "image_0",
        "image_1",
        "prompt_embeds",
        "pooled_prompt_embeds",
        "text_encoder_output",
        "vae_output",
    }
    list_fields = {
        "id",
        "original_size_hw_0",
        "original_size_hw_1",
        "crop_top_left_yx_0",
        "crop_top_left_yx_1",
        "prefer_0",
        "prefer_1",
        "caption",
    }

    first = examples[0]
    unknown = set(first.keys()) - (tensor_fields | list_fields)
    if unknown:
        raise ValueError(f"The following keys are not handled by the collate function: {unknown}.")

    batch = {key: torch.stack([ex[key] for ex in examples]) for key in tensor_fields if key in first}
    batch.update({key: [ex[key] for ex in examples] for key in list_fields if key in first})
    return batch
def build_image_pair_preference_sd_dataloader(
    config: ImagePairPreferenceSDDataLoaderConfig,
    batch_size: int,
    text_encoder_output_cache_dir: typing.Optional[str] = None,
    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,
    vae_output_cache_dir: typing.Optional[str] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Construct a DataLoader for an image-pair preference dataset for Stable Diffusion.

    Each example is expected to contain two images, "image_0" and "image_1". Each image is passed through its own
    SDImageTransform, which writes per-image "original_size_hw_{0,1}" and "crop_top_left_yx_{0,1}" fields.

    Args:
        config (ImagePairPreferenceSDDataLoaderConfig): The data loader config.
        batch_size (int): The DataLoader batch size.
        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached. If set,
            the cached outputs are loaded into each example (keyed by the example "id").
        text_encoder_cache_field_to_output_field (dict[str, str], optional): Maps cached text encoder field names
            to example field names. Required when `text_encoder_output_cache_dir` is set.
        vae_output_cache_dir (str, optional): VAE output caching is not yet supported by this data loader. Passing a
            value raises NotImplementedError.
        shuffle (bool, optional): Whether to shuffle the dataset order.

    Returns:
        DataLoader

    Raises:
        NotImplementedError: If `vae_output_cache_dir` is set.
        ValueError: If the dataset config type is unsupported, or if `text_encoder_output_cache_dir` is set without
            `text_encoder_cache_field_to_output_field`.
    """
    # Fail fast, before potentially expensive dataset loading (e.g. an HF Hub download).
    if vae_output_cache_dir is not None:
        # TODO: Support VAE caching by loading "vae_output" and the per-image size/crop fields from a
        # TensorDiskCache, and dropping the PIL "image_0"/"image_1" fields.
        raise NotImplementedError("VAE caching is not yet implemented.")

    if config.dataset.type == "HF_HUB_IMAGE_PAIR_PREFERENCE_DATASET":
        base_dataset = build_hf_image_pair_preference_dataset(config=config.dataset)
    elif config.dataset.type == "IMAGE_PAIR_PREFERENCE_DATASET":
        base_dataset = ImagePairPreferenceDataset(dataset_dir=config.dataset.dataset_dir)
    else:
        raise ValueError(f"Unexpected dataset config type: '{type(config.dataset)}'.")

    target_resolution = config.resolution

    # TODO(ryand): Should I process both images in a single SDImageTransform so that they undergo the same
    # transformations?
    all_transforms = [
        SDImageTransform(
            image_field_names=[f"image_{i}"],
            fields_to_normalize_to_range_minus_one_to_one=[f"image_{i}"],
            resolution=target_resolution,
            aspect_ratio_bucket_manager=None,
            center_crop=config.center_crop,
            random_flip=config.random_flip,
            orig_size_field_name=f"original_size_hw_{i}",
            crop_field_name=f"crop_top_left_yx_{i}",
        )
        for i in (0, 1)
    ]

    if text_encoder_output_cache_dir is not None:
        # Explicit check instead of `assert` so the validation is not stripped when running with `python -O`.
        if text_encoder_cache_field_to_output_field is None:
            raise ValueError(
                "'text_encoder_cache_field_to_output_field' must be set when 'text_encoder_output_cache_dir' is set."
            )
        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=text_encoder_cache,
                cache_key_field="id",
                cache_field_to_output_field=text_encoder_cache_field_to_output_field,
            )
        )

    dataset = TransformDataset(base_dataset, all_transforms)

    return DataLoader(
        dataset,
        shuffle=shuffle,
        collate_fn=sd_image_pair_preference_collate_fn,
        batch_size=batch_size,
        num_workers=config.dataloader_num_workers,
    )
================================================
FILE: src/invoke_training/_shared/data/data_loaders/textual_inversion_sd_dataloader.py
================================================
from typing import Literal, Optional
from torch.utils.data import DataLoader
from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
build_aspect_ratio_bucket_manager,
sd_image_caption_collate_fn,
)
from invoke_training._shared.data.datasets.build_dataset import (
build_hf_hub_image_caption_dataset,
build_image_caption_dir_dataset,
build_image_caption_jsonl_dataset,
)
from invoke_training._shared.data.datasets.image_dir_dataset import ImageDirDataset
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler
from invoke_training._shared.data.transforms.concat_fields_transform import ConcatFieldsTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform
from invoke_training._shared.data.transforms.shuffle_caption_transform import ShuffleCaptionTransform
from invoke_training._shared.data.transforms.template_caption_transform import TemplateCaptionTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig
from invoke_training.config.data.dataset_config import (
HFHubImageCaptionDatasetConfig,
ImageCaptionDirDatasetConfig,
ImageCaptionJsonlDatasetConfig,
ImageDirDatasetConfig,
)
def get_preset_ti_caption_templates(preset: Literal["object", "style"]) -> list[str]:
    """Return the built-in Textual Inversion caption templates for a preset.

    Every template contains exactly one "{}" placeholder for the token being learned.

    Args:
        preset: "object" for photo-style templates describing an object, or "style" for templates describing a work
            "in the style of" the placeholder.

    Returns:
        list[str]: A newly-built list of template strings (callers may mutate it freely).

    Raises:
        ValueError: If `preset` is neither "object" nor "style".
    """
    if preset not in ("object", "style"):
        raise ValueError(f"Unrecognized learnable property type: '{preset}'.")

    if preset == "object":
        return [
            "a photo of a {}",
            "a rendering of a {}",
            "a cropped photo of the {}",
            "the photo of a {}",
            "a photo of a clean {}",
            "a photo of a dirty {}",
            "a dark photo of the {}",
            "a photo of my {}",
            "a photo of the cool {}",
            "a close-up photo of a {}",
            "a bright photo of the {}",
            "a cropped photo of a {}",
            "a photo of the {}",
            "a good photo of the {}",
            "a photo of one {}",
            "a close-up photo of the {}",
            "a rendition of the {}",
            "a photo of the clean {}",
            "a rendition of a {}",
            "a photo of a nice {}",
            "a good photo of a {}",
            "a photo of the nice {}",
            "a photo of the small {}",
            "a photo of the weird {}",
            "a photo of the large {}",
            "a photo of a cool {}",
            "a photo of a small {}",
        ]

    # NOTE(review): "a close-up painting in the style of {}" appears twice in this list, so it is sampled twice as
    # often as the other templates. Kept as-is to preserve existing training behavior.
    return [
        "a painting in the style of {}",
        "a rendering in the style of {}",
        "a cropped painting in the style of {}",
        "the painting in the style of {}",
        "a clean painting in the style of {}",
        "a dirty painting in the style of {}",
        "a dark painting in the style of {}",
        "a picture in the style of {}",
        "a cool painting in the style of {}",
        "a close-up painting in the style of {}",
        "a bright painting in the style of {}",
        "a good painting in the style of {}",
        "a close-up painting in the style of {}",
        "a rendition in the style of {}",
        "a nice painting in the style of {}",
        "a small painting in the style of {}",
        "a weird painting in the style of {}",
        "a large painting in the style of {}",
        "a photo in the style of {}",
        "an image in the style of {}",
        "a drawing in the style of {}",
        "a sketch in the style of {}",
        "a digital work in the style of {}",
        "a digital rendering in the style of {}",
        "a photograph in the style of {}",
        "photography in the style of {}",
    ]
def build_textual_inversion_sd_dataloader(  # noqa: C901
    config: TextualInversionSDDataLoaderConfig,
    placeholder_token: str,
    batch_size: int,
    use_masks: bool = False,
    vae_output_cache_dir: Optional[str] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Construct a DataLoader for a Textual Inversion dataset for Stable Diffusion.

    Args:
        config (TextualInversionSDDataLoaderConfig): The data loader config. Exactly one of
            `config.caption_templates` or `config.caption_preset` must be set.
        placeholder_token (str): The placeholder token being trained. It is substituted into the caption templates.
        batch_size (int): The DataLoader batch size.
        use_masks (bool, optional): If True, keep a "mask" field in each example. It is either transformed alongside
            the image, or loaded from the VAE cache when `vae_output_cache_dir` is set. If False, any "mask" field
            is dropped.
        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached. If set, the image
            augmentation transforms are skipped, the "image" field is dropped, and the cached "vae_output",
            "original_size_hw" and "crop_top_left_yx" fields are loaded instead.
        shuffle (bool, optional): Whether to shuffle the dataset order.

    Returns:
        DataLoader

    Raises:
        ValueError: If the dataset config type is unsupported, or if not exactly one of `caption_templates` /
            `caption_preset` is set.
    """
    if isinstance(config.dataset, HFHubImageCaptionDatasetConfig):
        base_dataset = build_hf_hub_image_caption_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionJsonlDatasetConfig):
        base_dataset = build_image_caption_jsonl_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionDirDatasetConfig):
        base_dataset = build_image_caption_dir_dataset(config.dataset)
    elif isinstance(config.dataset, ImageDirDatasetConfig):
        base_dataset = ImageDirDataset(
            image_dir=config.dataset.dataset_dir, keep_in_memory=config.dataset.keep_in_memory
        )
    else:
        raise ValueError(f"Unexpected dataset config type: '{type(config.dataset)}'.")

    # Initialize either the fixed target resolution or aspect ratio buckets.
    if config.aspect_ratio_buckets is None:
        target_resolution = config.resolution
        aspect_ratio_bucket_manager = None
        batch_sampler = None
    else:
        target_resolution = None
        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)
        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.
        batch_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
            bucket_manager=aspect_ratio_bucket_manager,
            image_sizes=base_dataset.get_image_dimensions(),
            batch_size=batch_size,
            shuffle=shuffle,
            seed=0,
        )

    if sum([config.caption_templates is not None, config.caption_preset is not None]) != 1:
        raise ValueError("Either caption_templates or caption_preset must be set.")

    # Exactly one template source is set at this point (validated above).
    if config.caption_templates is not None:
        caption_templates = config.caption_templates
    else:
        caption_templates = get_preset_ti_caption_templates(config.caption_preset)

    # The templated caption overwrites the "caption" field (typically used with an ImageDirDataset that has no
    # captions). When keeping the original captions, it is written to a temporary "caption_prefix" field instead and
    # concatenated with the original caption below.
    all_transforms = [
        TemplateCaptionTransform(
            field_name="caption_prefix" if config.keep_original_captions else "caption",
            placeholder_str=placeholder_token,
            caption_templates=caption_templates,
        )
    ]

    if config.keep_original_captions:
        # This requires a base dataset whose examples already have a "caption" field.
        all_transforms.append(
            ConcatFieldsTransform(
                src_field_names=["caption_prefix", "caption"], dst_field_name="caption", separator=" "
            )
        )

    if config.shuffle_caption_delimiter is not None:
        all_transforms.append(ShuffleCaptionTransform(field_name="caption", delimiter=config.shuffle_caption_delimiter))

    if vae_output_cache_dir is None:
        image_field_names = ["image"]
        if use_masks:
            image_field_names.append("mask")
        else:
            all_transforms.append(DropFieldTransform("mask"))

        all_transforms.append(
            SDImageTransform(
                image_field_names=image_field_names,
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )
    else:
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))
        # The mask (if requested) is re-loaded from the VAE cache below.
        all_transforms.append(DropFieldTransform("mask"))

        vae_cache = TensorDiskCache(vae_output_cache_dir)
        cache_field_to_output_field = {
            "vae_output": "vae_output",
            "original_size_hw": "original_size_hw",
            "crop_top_left_yx": "crop_top_left_yx",
        }
        if use_masks:
            cache_field_to_output_field["mask"] = "mask"
        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache,
                cache_key_field="id",
                cache_field_to_output_field=cache_field_to_output_field,
            )
        )

    dataset = TransformDataset(base_dataset, all_transforms)

    if batch_sampler is None:
        return DataLoader(
            dataset,
            shuffle=shuffle,
            collate_fn=sd_image_caption_collate_fn,
            batch_size=batch_size,
            num_workers=config.dataloader_num_workers,
            persistent_workers=config.dataloader_num_workers > 0,
        )

    # The bucket batch sampler already handles batching and shuffling.
    return DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        collate_fn=sd_image_caption_collate_fn,
        num_workers=config.dataloader_num_workers,
        persistent_workers=config.dataloader_num_workers > 0,
    )
================================================
FILE: src/invoke_training/_shared/data/datasets/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/data/datasets/build_dataset.py
================================================
from datasets import VerificationMode
from invoke_training._shared.data.datasets.hf_image_caption_dataset import HFImageCaptionDataset
from invoke_training._shared.data.datasets.hf_image_pair_preference_dataset import HFImagePairPreferenceDataset
from invoke_training._shared.data.datasets.image_caption_dir_dataset import ImageCaptionDirDataset
from invoke_training._shared.data.datasets.image_caption_jsonl_dataset import ImageCaptionJsonlDataset
from invoke_training.config.data.dataset_config import (
HFHubImageCaptionDatasetConfig,
ImageCaptionDirDatasetConfig,
ImageCaptionJsonlDatasetConfig,
)
from invoke_training.pipelines._experimental.sd_dpo_lora.config import HFHubImagePairPreferenceDatasetConfig
def build_hf_hub_image_caption_dataset(config: HFHubImageCaptionDatasetConfig) -> HFImageCaptionDataset:
    """Load an image-caption dataset from the Hugging Face Hub as described by `config`.

    `config.dataset_config_name` and `config.hf_cache_dir` are forwarded to `datasets.load_dataset(...)` as the
    `name` and `cache_dir` kwargs respectively.
    """
    load_kwargs = {"name": config.dataset_config_name, "cache_dir": config.hf_cache_dir}
    return HFImageCaptionDataset.from_hub(
        dataset_name=config.dataset_name,
        hf_load_dataset_kwargs=load_kwargs,
        image_column=config.image_column,
        caption_column=config.caption_column,
    )
def build_image_caption_jsonl_dataset(config: ImageCaptionJsonlDatasetConfig) -> ImageCaptionJsonlDataset:
    """Build an image-caption dataset from a .jsonl file as described by `config`.

    Returns:
        ImageCaptionJsonlDataset: The dataset. (The return annotation previously claimed HFImageCaptionDataset,
            which is not what this function constructs.)
    """
    return ImageCaptionJsonlDataset(
        jsonl_path=config.jsonl_path,
        image_column=config.image_column,
        caption_column=config.caption_column,
        keep_in_memory=config.keep_in_memory,
    )
def build_image_caption_dir_dataset(config: ImageCaptionDirDatasetConfig) -> ImageCaptionDirDataset:
    """Build an image-caption dataset from a directory of images and caption files, as described by `config`."""
    dataset = ImageCaptionDirDataset(dataset_dir=config.dataset_dir, keep_in_memory=config.keep_in_memory)
    return dataset
def build_hf_image_pair_preference_dataset(
    config: HFHubImagePairPreferenceDatasetConfig,
) -> HFImagePairPreferenceDataset:
    """Build the Hugging Face Hub image-pair preference dataset.

    NOTE(review): `config` is currently ignored. The dataset name, split and shard list are all hard-coded below.
    """
    # HACK(ryand): This is currently hard-coded to just download a small slice of the very large
    # 'yuvalkirstain/pickapic_v2' dataset.
    train_shards = [
        "data/train-00000-of-00645-b66ac786bf6fb553.parquet",
        "data/train-00001-of-00645-c7b349dd222d6515.parquet",
        "data/train-00002-of-00645-e4f54d615a978deb.parquet",
        "data/train-00003-of-00645-2b9d59bac8b433ff.parquet",
        "data/train-00004-of-00645-e4964649dc0ea543.parquet",
        "data/train-00005-of-00645-45e8efc0fe93f6e9.parquet",
    ]
    load_kwargs = {
        "data_files": {"train": train_shards},
        # Disable checks so that loading doesn't complain that the other splits/shards weren't downloaded.
        "verification_mode": VerificationMode.NO_CHECKS,
    }
    return HFImagePairPreferenceDataset.from_hub(
        "yuvalkirstain/pickapic_v2",
        split="train",
        hf_load_dataset_kwargs=load_kwargs,
    )
================================================
FILE: src/invoke_training/_shared/data/datasets/hf_image_caption_dataset.py
================================================
import os
import typing
import datasets
import torch.utils.data
from PIL.Image import Image
from invoke_training._shared.data.utils.resolution import Resolution
class HFImageCaptionDataset(torch.utils.data.Dataset):
    """An image-caption dataset wrapper for Hugging Face datasets.

    The wrapped HF dataset can be either from the HF hub, or in Imagefolder format
    (https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder).

    Only the "train" split is used. Regardless of the source column names, examples are exposed with the normalized
    keys "image" (an RGB `PIL` image) and "caption".
    """

    def __init__(self, hf_dataset, image_column: str = "image", caption_column: str = "text"):
        """
        Args:
            hf_dataset: A split-indexable HF dataset (e.g. the result of `datasets.load_dataset(...)`) that has a
                "train" split.
            image_column (str, optional): The name of the image column in the source dataset.
            caption_column (str, optional): The name of the caption column in the source dataset.

        Raises:
            ValueError: If either column is missing from the "train" split.
        """
        column_names = hf_dataset["train"].column_names
        if image_column not in column_names:
            raise ValueError(
                f"The image_column='{image_column}' is not in the set of dataset column names: '{column_names}'."
            )
        if caption_column not in column_names:
            raise ValueError(
                f"The caption_column='{caption_column}' is not in the set of dataset column names: '{column_names}'."
            )

        # Name of the image column in the *source* dataset. Note that transformed examples always use the "image"
        # key (see `preprocess`), so this must not be used to index into `self._hf_dataset` results.
        self._image_column = image_column

        def preprocess(examples):
            # Batched transform: normalize column names and force RGB images.
            images = [image.convert("RGB") for image in examples[image_column]]
            return {
                "image": images,
                "caption": examples[caption_column],
            }

        self._hf_dataset = hf_dataset["train"].with_transform(preprocess)

    @classmethod
    def from_dir(
        cls,
        dataset_dir: str,
        hf_load_dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None,
        image_column: str = "image",
        caption_column: str = "text",
    ):
        """Initialize a HFImageCaptionDataset from a Hugging Face ImageFolder dataset directory
        (https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder).

        Args:
            dataset_dir (str): The path to the dataset directory.
            hf_load_dataset_kwargs (dict[str, typing.Any], optional): kwargs to forward to `datasets.load_dataset(...)`.
            image_column (str, optional): The name of the image column in the dataset. Defaults to "image".
            caption_column (str, optional): The name of the caption column in the dataset. Defaults to "text".
        """
        hf_load_dataset_kwargs = hf_load_dataset_kwargs or {}
        data_files = {"train": os.path.join(dataset_dir, "**")}
        # See more about loading custom images at
        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
        hf_dataset = datasets.load_dataset("imagefolder", data_files=data_files, **hf_load_dataset_kwargs)
        return cls(hf_dataset=hf_dataset, image_column=image_column, caption_column=caption_column)

    @classmethod
    def from_hub(
        cls,
        dataset_name: str,
        hf_load_dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None,
        image_column: str = "image",
        caption_column: str = "text",
    ):
        """Initialize a HFImageCaptionDataset from a Hugging Face Hub dataset.

        Args:
            dataset_name (str): The HF Hub dataset name (a.k.a. path).
            hf_load_dataset_kwargs (dict[str, typing.Any], optional): kwargs to forward to `datasets.load_dataset(...)`.
            image_column (str, optional): The name of the image column in the dataset. Defaults to "image".
            caption_column (str, optional): The name of the caption column in the dataset. Defaults to "text".
        """
        hf_load_dataset_kwargs = hf_load_dataset_kwargs or {}
        hf_dataset = datasets.load_dataset(dataset_name, **hf_load_dataset_kwargs)
        return cls(hf_dataset=hf_dataset, image_column=image_column, caption_column=caption_column)

    def get_image_dimensions(self) -> list[Resolution]:
        """Get the dimensions of all images in the dataset.

        Each example is fetched through the preprocessing transform, so every image is decoded (and converted to
        RGB) to read its size.

        TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to
        calculate this dynamically every time.
        """
        image_dims: list[Resolution] = []
        for i in range(len(self._hf_dataset)):
            example = self._hf_dataset[i]
            # Transformed examples always use the normalized "image" key, regardless of the source image column
            # name. (Indexing by `self._image_column` raised a KeyError for non-default column names.)
            image: Image = example["image"]
            image_dims.append(Resolution(image.height, image.width))

        return image_dims

    def __len__(self) -> int:
        """Get the dataset length.

        Returns:
            int: The number of image-caption pairs in the dataset.
        """
        return len(self._hf_dataset)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        """Load the dataset example at index `idx`.

        Raises:
            IndexError: If `idx` is out of range.

        Returns:
            dict: A dataset example with 3 keys: "image", "caption", and "id".
                The "image" key maps to a `PIL` image in RGB format.
                The "caption" key maps to a string.
                The "id" key is the example's index (often used for caching).
        """
        example = self._hf_dataset[idx]
        example["id"] = idx
        return example
================================================
FILE: src/invoke_training/_shared/data/datasets/hf_image_pair_preference_dataset.py
================================================
import io
import typing
import datasets
import torch.utils.data
from PIL import Image
class HFImagePairPreferenceDataset(torch.utils.data.Dataset):
    """A wrapper for the Hugging Face hub "yuvalkirstain/pickapic_v2" dataset
    (https://huggingface.co/datasets/yuvalkirstain/pickapic_v2).

    Designed to be expanded in the future to other HF image pair preference datasets.
    """

    def __init__(
        self,
        hf_dataset,
        skip_no_preference: bool = True,
        split: str = "train",
        image_0_column: str = "jpg_0",
        label_0_column: str = "label_0",
        image_1_column: str = "jpg_1",
        label_1_column: str = "label_1",
        caption_column: str = "caption",
    ):
        """
        Args:
            hf_dataset: A split-indexable HF dataset (e.g. the result of `datasets.load_dataset(...)`).
            skip_no_preference (bool, optional): If True, skip image pairs without a preference (i.e. where the two
                labels differ by no more than a small epsilon).
            split (str, optional): The split of `hf_dataset` to use.
            image_0_column (str, optional): The column holding the encoded bytes of the first image.
            label_0_column (str, optional): The column holding the numeric preference label of the first image.
            image_1_column (str, optional): The column holding the encoded bytes of the second image.
            label_1_column (str, optional): The column holding the numeric preference label of the second image.
            caption_column (str, optional): The caption column.

        Raises:
            ValueError: If any of the configured columns is missing from the selected split.
        """
        split_dataset = hf_dataset[split]

        column_names = split_dataset.column_names
        for col_name in [image_0_column, label_0_column, image_1_column, label_1_column, caption_column]:
            if col_name not in column_names:
                raise ValueError(f"Column '{col_name}' is not in the set of dataset column names: '{column_names}'.")

        # Label differences at or below this threshold are treated as a tie (no preference).
        eps = 0.0001

        if skip_no_preference:
            # Filter to only include pairs with a clear preference. Only the selected split is filtered, since the
            # other splits are never used.
            def has_preference(example: dict[str, typing.Any]) -> bool:
                return abs(example[label_0_column] - example[label_1_column]) > eps

            split_dataset = split_dataset.filter(has_preference)

        def preprocess(examples):
            # Batched transform: decode the image bytes and turn the numeric labels into per-image preference flags.
            image_0_list = [Image.open(io.BytesIO(image)).convert("RGB") for image in examples[image_0_column]]
            image_1_list = [Image.open(io.BytesIO(image)).convert("RGB") for image in examples[image_1_column]]

            image_0_is_better = []
            image_1_is_better = []
            for label_0, label_1 in zip(examples[label_0_column], examples[label_1_column]):
                if (label_0 - label_1) > eps:
                    # Label 0 is better.
                    image_0_is_better.append(True)
                    image_1_is_better.append(False)
                elif (label_1 - label_0) > eps:
                    # Label 1 is better.
                    image_0_is_better.append(False)
                    image_1_is_better.append(True)
                else:
                    # Tie.
                    image_0_is_better.append(False)
                    image_1_is_better.append(False)

            return {
                "image_0": image_0_list,
                "image_1": image_1_list,
                "prefer_0": image_0_is_better,
                "prefer_1": image_1_is_better,
                "caption": examples[caption_column],
            }

        self._hf_dataset = split_dataset.with_transform(preprocess)

    @classmethod
    def from_hub(
        cls,
        dataset_name: str,
        skip_no_preference: bool = True,
        split: str = "train",
        hf_load_dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None,
    ):
        """Initialize a HFImagePairPreferenceDataset from a Hugging Face Hub dataset.

        Args:
            dataset_name (str): The HF Hub dataset name (a.k.a. path). Currently only "yuvalkirstain/pickapic_v2"
                is supported.
            skip_no_preference (bool, optional): If True, skip image pairs without a preference.
            split (str, optional): The split to use.
            hf_load_dataset_kwargs (dict[str, typing.Any], optional): kwargs to forward to `datasets.load_dataset(...)`.

        Raises:
            NotImplementedError: If `dataset_name` is not "yuvalkirstain/pickapic_v2".
        """
        if dataset_name != "yuvalkirstain/pickapic_v2":
            raise NotImplementedError(
                "The HFImagePairPreferenceDataset class likely won't work with datasets other than "
                "'yuvalkirstain/pickapic_v2'."
            )
        hf_load_dataset_kwargs = hf_load_dataset_kwargs or {}
        hf_dataset = datasets.load_dataset(dataset_name, **hf_load_dataset_kwargs)
        return cls(hf_dataset=hf_dataset, skip_no_preference=skip_no_preference, split=split)

    def __len__(self) -> int:
        """Get the dataset length.

        Returns:
            int: The number of image pairs in the dataset.
        """
        return len(self._hf_dataset)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        """Load the dataset example at index `idx`.

        Raises:
            IndexError: If `idx` is out of range.

        Returns:
            dict: A dataset example with the following keys: ["id", "image_0", "image_1", "prefer_0", "prefer_1",
                "caption"].
                The image keys map to `PIL` images in RGB format.
                The "prefer_*" keys map to bools (both False for a tie).
                The "caption" key maps to the caption column value.
                The "id" key is the example's index (often used for caching).
        """
        example = self._hf_dataset[idx]
        example["id"] = idx
        return example
================================================
FILE: src/invoke_training/_shared/data/datasets/image_caption_dir_dataset.py
================================================
import os
import typing
import torch.utils.data
from PIL import Image
from invoke_training._shared.data.utils.resolution import Resolution
class ImageCaptionDirDataset(torch.utils.data.Dataset):
    """A dataset that loads images and captions from a directory of image files and caption text files.

    Every image file must have a sibling caption file with the same stem and the configured caption extension
    (e.g. `img_001.png` -> `img_001.txt`).
    """

    def __init__(
        self,
        dataset_dir: str,
        id_prefix: str = "",
        image_extensions: typing.Optional[list[str]] = None,
        caption_extension: str = ".txt",
        keep_in_memory: bool = False,
    ):
        """Initialize an ImageCaptionDirDataset.

        Args:
            dataset_dir (str): The directory to load images and caption files from. Only the top level of the
                directory is scanned (no recursion).
            id_prefix (str): A prefix added to the 'id' field in every example.
            image_extensions (list[str], optional): The list of image file extensions to include in the dataset (not
                case-sensitive). Defaults to [".jpg", ".jpeg", ".png"].
            caption_extension (str, optional): The extension of the caption file that accompanies each image.
                Defaults to ".txt". Caption files are read as UTF-8 and stripped of surrounding whitespace.
            keep_in_memory (bool, optional): If True, keep all images loaded in memory. This improves performance for
                datasets that are small enough to be kept in memory.

        Raises:
            Exception: If one or more images have no matching caption file.
        """
        super().__init__()
        self._id_prefix = id_prefix

        if image_extensions is None:
            image_extensions = [".jpg", ".jpeg", ".png"]
        # Extensions are compared case-insensitively; a set gives O(1) membership tests.
        allowed_extensions = {ext.lower() for ext in image_extensions}

        # Determine the list of image paths to include in the dataset. Sorting makes the index -> image mapping (and
        # therefore the example ids) deterministic across filesystems.
        self._image_paths: list[str] = []
        for image_file in os.listdir(dataset_dir):
            image_path = os.path.join(dataset_dir, image_file)
            if os.path.isfile(image_path) and os.path.splitext(image_path)[1].lower() in allowed_extensions:
                self._image_paths.append(image_path)
        self._image_paths.sort()

        # Load captions from the caption file for each image. Collect every missing file before failing, so the user
        # sees the full list in one error.
        self._captions: list[str] = []
        missing_captions: list[str] = []
        for image_path in self._image_paths:
            caption_path = os.path.splitext(image_path)[0] + caption_extension
            if os.path.isfile(caption_path):
                # Explicit UTF-8 so non-ASCII captions are read the same way regardless of the platform locale.
                with open(caption_path, "r", encoding="utf-8") as f:
                    self._captions.append(f.read().strip())
            else:
                missing_captions.append(caption_path)

        if len(missing_captions) > 0:
            raise Exception(f"The following expected caption files are missing: {missing_captions}")

        self._images = None
        if keep_in_memory:
            self._images = [self._load_image(image_path) for image_path in self._image_paths]

    def _load_image(self, image_path: str) -> Image.Image:
        # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale
        # images.
        return Image.open(image_path).convert("RGB")

    def get_image_dimensions(self) -> list[Resolution]:
        """Get the dimensions of all images in the dataset.

        Only image headers are read; each file is closed as soon as its size is known.

        TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to
        calculate this dynamically every time.
        """
        image_dims: list[Resolution] = []
        for image_path in self._image_paths:
            # Use a context manager so the lazily-opened file handle is released immediately rather than leaking
            # until garbage collection.
            with Image.open(image_path) as image:
                image_dims.append(Resolution(image.height, image.width))
        return image_dims

    def __len__(self) -> int:
        return len(self._image_paths)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        """Return the example at `idx` as a dict with keys "id", "image" (RGB PIL image) and "caption"."""
        image = self._images[idx] if self._images is not None else self._load_image(self._image_paths[idx])
        return {"id": f"{self._id_prefix}{idx}", "image": image, "caption": self._captions[idx]}
================================================
FILE: src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py
================================================
import typing
from pathlib import Path
import torch.utils.data
from PIL import Image
from pydantic import BaseModel
from invoke_training._shared.data.utils.resolution import Resolution
from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl
# Default jsonl column holding the image path (absolute, or relative to the jsonl file's directory).
IMAGE_COLUMN_DEFAULT = "image"
# Default jsonl column holding the caption text.
CAPTION_COLUMN_DEFAULT = "text"
# jsonl column holding an optional mask image path. Unlike the image/caption columns, this name is not configurable
# in ImageCaptionJsonlDataset.
MASK_COLUMN_DEFAULT = "mask"
class ImageCaptionExample(BaseModel):
    """A single image-caption record, as parsed from one line of a jsonl dataset file."""

    # Path to the image file; may be absolute or relative to the jsonl file's directory.
    image_path: str
    # Optional path to a mask image (same path rules as `image_path`); None when the record has no mask.
    mask_path: str | None = None
    # The caption text for the image.
    caption: str
class ImageCaptionJsonlDataset(torch.utils.data.Dataset):
    """A dataset that loads image-caption examples (with optional masks) listed in a .jsonl file.

    Each jsonl record must contain an image-path column and a caption column, and may contain a mask-path column
    (always named by `MASK_COLUMN_DEFAULT`). Relative image/mask paths are resolved against the directory that
    contains the jsonl file.
    """

    def __init__(
        self,
        jsonl_path: Path | str,
        image_column: str = IMAGE_COLUMN_DEFAULT,
        caption_column: str = CAPTION_COLUMN_DEFAULT,
        keep_in_memory: bool = False,
    ):
        """Initialize an ImageCaptionJsonlDataset.

        Args:
            jsonl_path (Path | str): Path to the .jsonl file that describes the dataset.
            image_column (str, optional): Name of the column holding image paths.
            caption_column (str, optional): Name of the column holding captions.
            keep_in_memory (bool, optional): If True, cache each loaded example (including its images) the first
                time it is accessed, so later accesses do not reload it from disk.

        Raises:
            ValueError: If any record is missing the image column or the caption column.
        """
        super().__init__()
        self._jsonl_path = Path(jsonl_path)
        self._image_column = image_column
        self._caption_column = caption_column

        data = load_jsonl(jsonl_path)
        examples: list[ImageCaptionExample] = []
        for d in data:
            # Clear error messages here are helpful in the Gradio UI.
            if image_column not in d:
                raise ValueError(f"Column '{image_column}' not found in jsonl file '{jsonl_path}'.")
            if caption_column not in d:
                raise ValueError(f"Column '{caption_column}' not found in jsonl file '{jsonl_path}'.")
            examples.append(
                ImageCaptionExample(
                    image_path=d[image_column], mask_path=d.get(MASK_COLUMN_DEFAULT, None), caption=d[caption_column]
                )
            )
        self.examples = examples

        self._keep_in_memory = keep_in_memory
        # Maps example index -> fully loaded example dict (only populated when keep_in_memory is True).
        self._example_cache: dict[int, dict[str, typing.Any]] = {}

    def save_jsonl(self):
        """Write `self.examples` back to this dataset's jsonl path.

        The configured image/caption column names are used. The mask column is always written (with a None value
        for examples that have no mask).
        """
        data = []
        for example in self.examples:
            data.append(
                {
                    self._image_column: example.image_path,
                    self._caption_column: example.caption,
                    MASK_COLUMN_DEFAULT: example.mask_path,
                }
            )
        save_jsonl(data, self._jsonl_path)

    def _resolve_path(self, path: str) -> Path:
        """Return `path` unchanged if absolute, otherwise resolved relative to the jsonl file's directory."""
        resolved = Path(path)
        if not resolved.is_absolute():
            resolved = self._jsonl_path.parent / resolved
        return resolved

    def _get_image_path(self, idx: int) -> Path:
        # image_path could be either absolute, or relative to the jsonl file.
        return self._resolve_path(self.examples[idx].image_path)

    def _get_mask_path(self, idx: int) -> Path:
        # mask_path could be either absolute, or relative to the jsonl file. Callers must check that the example has
        # a mask first.
        return self._resolve_path(self.examples[idx].mask_path)

    def _load_image(self, image_path: str | Path) -> Image.Image:
        # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale
        # images.
        return Image.open(image_path).convert("RGB")

    def _load_mask(self, mask_path: str | Path) -> Image.Image:
        # Masks are loaded as single-channel ("L") images.
        return Image.open(mask_path).convert("L")

    def _load_example(self, idx: int) -> dict[str, typing.Any]:
        """Load example `idx` from disk: keys "id", "image", "caption", plus "mask" when the record has one."""
        example = {
            "id": str(idx),
            "image": self._load_image(self._get_image_path(idx)),
            "caption": self.examples[idx].caption,
        }
        if self.examples[idx].mask_path:
            example["mask"] = self._load_mask(self._get_mask_path(idx))
        return example

    def get_image_dimensions(self) -> list[Resolution]:
        """Get the dimensions of all images in the dataset.

        Only image headers are read; each file is closed as soon as its size is known.

        TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to
        calculate this dynamically every time.
        """
        image_dims: list[Resolution] = []
        for i in range(len(self.examples)):
            # Context manager releases the lazily-opened file handle immediately instead of leaking it until GC.
            with Image.open(self._get_image_path(i)) as image:
                image_dims.append(Resolution(image.height, image.width))
        return image_dims

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        if self._keep_in_memory:
            if idx not in self._example_cache:
                self._example_cache[idx] = self._load_example(idx)
            # Return a shallow copy of the example to prevent the caller from modifying the cached example.
            # Shallow rather than deep, because we don't want to copy the image data.
            return self._example_cache[idx].copy()

        return self._load_example(idx)
================================================
FILE: src/invoke_training/_shared/data/datasets/image_dir_dataset.py
================================================
import os
import typing
import torch.utils.data
from PIL import Image
from invoke_training._shared.data.utils.resolution import Resolution
class ImageDirDataset(torch.utils.data.Dataset):
    """A dataset that loads image files from a directory."""

    def __init__(
        self,
        image_dir: str,
        id_prefix: str = "",
        image_extensions: typing.Optional[list[str]] = None,
        keep_in_memory: bool = False,
    ):
        """Initialize an ImageDirDataset.

        Args:
            image_dir (str): The directory to load images from. Only the top level is scanned (no recursion).
            id_prefix (str): A prefix added to the 'id' field in every example.
            image_extensions (list[str], optional): The list of image file extensions to include in the dataset (not
                case-sensitive). Defaults to [".jpg", ".jpeg", ".png"].
            keep_in_memory (bool, optional): If True, keep all images loaded in memory. This improves performance for
                datasets that are small enough to be kept in memory.
        """
        super().__init__()
        self._id_prefix = id_prefix

        if image_extensions is None:
            image_extensions = [".jpg", ".jpeg", ".png"]
        # Extensions are compared case-insensitively; a set gives O(1) membership tests.
        allowed_extensions = {ext.lower() for ext in image_extensions}

        self._image_paths: list[str] = []
        for image_file in os.listdir(image_dir):
            image_path = os.path.join(image_dir, image_file)
            if os.path.isfile(image_path) and os.path.splitext(image_path)[1].lower() in allowed_extensions:
                self._image_paths.append(image_path)
        # os.listdir order is arbitrary; sort so that example ids (derived from the index) are deterministic.
        self._image_paths.sort()

        self._images = None
        if keep_in_memory:
            self._images = [self._load_image(image_path) for image_path in self._image_paths]

    def _load_image(self, image_path: str) -> Image.Image:
        # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale
        # images.
        return Image.open(image_path).convert("RGB")

    def get_image_dimensions(self) -> list[Resolution]:
        """Get the dimensions of all images in the dataset.

        Only image headers are read; each file is closed as soon as its size is known.

        TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to
        calculate this dynamically every time.
        """
        image_dims: list[Resolution] = []
        for image_path in self._image_paths:
            # Context manager releases the lazily-opened file handle immediately instead of leaking it until GC.
            with Image.open(image_path) as image:
                image_dims.append(Resolution(image.height, image.width))
        return image_dims

    def __len__(self) -> int:
        return len(self._image_paths)

    def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
        """Return the example at `idx` as a dict with keys "id" and "image" (RGB PIL image)."""
        image = self._images[idx] if self._images is not None else self._load_image(self._image_paths[idx])
        return {"id": f"{self._id_prefix}{idx}", "image": image}
================================================
FILE: src/invoke_training/_shared/data/datasets/image_pair_preference_dataset.py
================================================
import os
import typing
from pathlib import Path
import torch.utils.data
from PIL import Image
from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl
class ImagePairPreferenceDataset(torch.utils.data.Dataset):
def __init__(self, dataset_dir: str):
super().__init__()
self._dataset_dir = dataset_dir
self._metadata = load_jsonl(Path(dataset_dir) / "metadata.jsonl")
@classmethod
def save_metadata(
cls, metadata: list[dict[str, typing.Any]], dataset_dir: str | Path, metadata_file: str = "metadata.jsonl"
) -> Path:
"""Load the dataset metadata from metadata.jsonl."""
metadata_path = Path(dataset_dir) / metadata_file
save_jsonl(metadata, metadata_path)
return metadata_path
def __len__(self) -> int:
return len(self._metadata)
def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]:
# We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale
# images.
example = self._metadata[idx]
image_0_path = os.path.join(self._dataset_dir, example["image_0"])
image_1_path = os.path.join(self._dataset_dir, example["image_1"])
return {
"id": str(idx),
"image_0": Image.open(image_0_path).convert("RGB"),
"image_1": Image.open(image_1_path).convert("RGB"),
"caption": example["prompt"],
"prefer_0": example["prefer_0"],
"prefer_1": example["prefer_1"],
}
================================================
FILE: src/invoke_training/_shared/data/datasets/transform_dataset.py
================================================
import typing
import torch.utils.data
# The data type expected to be produced by the base dataset and handled by transforms: one example as a dict keyed
# by field name.
DataType = typing.Dict[str, typing.Any]
# A transform receives one example and returns the (possibly modified) example.
TransformType = typing.Callable[[DataType], DataType]


class TransformDataset(torch.utils.data.Dataset):
    """Wrap a base dataset and pipe each example it produces through an ordered chain of transforms."""

    def __init__(self, base_dataset: torch.utils.data.Dataset, transforms: list[TransformType]) -> None:
        super().__init__()
        self._base_dataset = base_dataset
        self._transforms = transforms

    def __len__(self) -> int:
        # Transforms operate one example at a time, so the length is exactly that of the base dataset.
        base = self._base_dataset
        return len(base)

    def __getitem__(self, idx: int) -> DataType:
        # Apply each transform to the output of the previous one, in list order.
        current = self._base_dataset[idx]
        for transform_fn in self._transforms:
            current = transform_fn(current)
        return current
================================================
FILE: src/invoke_training/_shared/data/samplers/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/data/samplers/aspect_ratio_bucket_batch_sampler.py
================================================
import copy
import logging
import math
import random
from typing import Iterator
from torch.utils.data import Sampler
from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager
from invoke_training._shared.data.utils.resolution import Resolution
# Maps each bucket resolution to the list of dataset indices assigned to that bucket.
AspectRatioBuckets = dict[Resolution, list[int]]
class AspectRatioBucketBatchSampler(Sampler[list[int]]):
    """A batch sampler that adheres to aspect ratio buckets.

    Every batch it yields contains dataset indices drawn from a single bucket.
    """

    def __init__(
        self,
        buckets: AspectRatioBuckets,
        batch_size: int,
        shuffle: bool = False,
        seed: int | None = None,
    ) -> None:
        """Initialize AspectRatioBucketBatchSampler.

        For most use cases, initialize via AspectRatioBucketBatchSampler.from_image_sizes(...).

        Args:
            buckets (AspectRatioBuckets): Map from bucket resolution to the dataset indices assigned to that bucket.
            batch_size (int): Maximum number of indices per batch. The last batch of each bucket may be smaller.
            shuffle (bool, optional): If True, shuffle the indices within each bucket and then the order of all
                batches, on every iteration.
            seed (int | None, optional): Seed for this sampler's private `random.Random` instance.

        Raises:
            ValueError: If `batch_size` is less than 1 (iteration would otherwise never terminate).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, but got {batch_size}.")
        self._buckets = buckets
        self._batch_size = batch_size
        self._shuffle = shuffle
        # A private RNG so shuffling does not consume (or depend on) the global `random` state.
        self._random = random.Random(seed)

    def __str__(self) -> str:
        # One line per bucket, in sorted resolution order: "<resolution tuple>: <number of images>".
        return "".join(
            f" {bucket_resolution.to_tuple()}: {len(self._buckets[bucket_resolution])}\n"
            for bucket_resolution in sorted(self._buckets)
        )

    @classmethod
    def from_image_sizes(
        cls,
        bucket_manager: AspectRatioBucketManager,
        image_sizes: list[Resolution],
        batch_size: int,
        shuffle: bool = False,
        seed: int | None = None,
    ):
        """Initialize from an AspectRatioBucketManager and the list of dataset image resolutions.

        `image_sizes[i]` must be the resolution of dataset example `i`.
        """
        buckets = cls._build_bucket_to_index_map(bucket_manager, image_sizes)
        return cls(buckets=buckets, batch_size=batch_size, shuffle=shuffle, seed=seed)

    @classmethod
    def _build_bucket_to_index_map(
        cls,
        bucket_manager: AspectRatioBucketManager,
        image_sizes: list[Resolution],
    ) -> AspectRatioBuckets:
        """Assign each dataset index to the bucket chosen by `bucket_manager` for its image size."""
        # Start with every bucket present (possibly empty) so the map reflects the full bucket set.
        bucket_to_indexes: AspectRatioBuckets = {bucket_resolution: [] for bucket_resolution in bucket_manager.buckets}
        for index, image_size in enumerate(image_sizes):
            aspect_ratio_bucket = bucket_manager.get_aspect_ratio_bucket(image_size)
            bucket_to_indexes[aspect_ratio_bucket].append(index)
        return bucket_to_indexes

    def get_buckets(self) -> AspectRatioBuckets:
        """Return a deep copy of the bucket map, so callers cannot mutate the sampler's state."""
        return copy.deepcopy(self._buckets)

    def __iter__(self) -> Iterator[list[int]]:
        batches: list[list[int]] = []
        # TODO(ryand): If self._shuffle == False, should we still shuffle just with a fixed seed every time? If we
        # don't shuffle at all then all of the batches from a bucket will be grouped together. If there's a correlation
        # between aspect ratio and image content in a dataset, this could result in unevenly distributed image content
        # over the dataset.
        for bucket_resolution in sorted(self._buckets):
            ordered_bucket_images = self._buckets[bucket_resolution].copy()
            if self._shuffle:
                # Shuffle the images within a bucket.
                self._random.shuffle(ordered_bucket_images)

            # Split the bucket into consecutive batches; the final batch may be shorter than batch_size.
            batches.extend(
                ordered_bucket_images[batch_start : batch_start + self._batch_size]
                for batch_start in range(0, len(ordered_bucket_images), self._batch_size)
            )

        if self._shuffle:
            # We've already shuffled the images within each bucket, now we shuffle the batches.
            self._random.shuffle(batches)

        yield from batches

    def __len__(self) -> int:
        # Each bucket contributes ceil(len(bucket) / batch_size) batches.
        return sum(math.ceil(len(bucket_images) / self._batch_size) for bucket_images in self._buckets.values())
def log_aspect_ratio_buckets(logger: logging.Logger, batch_sampler: AspectRatioBucketBatchSampler):
    """Log a per-bucket summary (resolution and image count) of an AspectRatioBucketBatchSampler at INFO level.

    Samplers of any other type are silently ignored, so this is safe to call unconditionally.
    """
    if isinstance(batch_sampler, AspectRatioBucketBatchSampler):
        summary = str(batch_sampler)
        logger.info("Aspect Ratio Buckets:\n" + summary)
================================================
FILE: src/invoke_training/_shared/data/samplers/batch_offset_sampler.py
================================================
import typing
from torch.utils.data import Sampler
class BatchOffsetSampler(Sampler[list[int]]):
    """A sampler that wraps a batch sampler and adds a constant offset to every index of every batch.

    The wrapped sampler yields batches (lists of indices), and so does this one; the type parameters reflect that
    (the previous `Sampler[int]` / `Iterator[int]` annotations were incorrect).
    """

    def __init__(self, sampler: Sampler[list[int]], offset: int):
        """Initialize BatchOffsetSampler.

        Args:
            sampler (Sampler[list[int]]): The batch sampler to wrap. Any iterable of index lists with a `len()`
                works.
            offset (int): Value added to every index in every batch.
        """
        self._sampler = sampler
        self._offset = offset

    def __iter__(self) -> typing.Iterator[list[int]]:
        for batch in self._sampler:
            yield [x + self._offset for x in batch]

    def __len__(self) -> int:
        # Offsetting does not change the number of batches.
        return len(self._sampler)
================================================
FILE: src/invoke_training/_shared/data/samplers/concat_sampler.py
================================================
import itertools
import typing
from torch.utils.data import Sampler
T_co = typing.TypeVar("T_co", covariant=True)
class ConcatSampler(Sampler[T_co]):
"""A meta-Sampler that concatenates multiple samplers.
Example:
sampler 1: ABCD
sampler 2: EFG
sampler 3: HIJKLM
ConcatSampler: ABCDEFGHIJKLM
"""
def __init__(self, samplers: list[Sampler[T_co] | typing.Iterable[T_co]]) -> None:
self._samplers = samplers
def __iter__(self) -> typing.Iterator[T_co]:
return itertools.chain(*self._samplers)
def __len__(self) -> int:
return sum([len(s) for s in self._samplers])
================================================
FILE: src/invoke_training/_shared/data/samplers/interleaved_sampler.py
================================================
import typing
from torch.utils.data import Sampler
T_co = typing.TypeVar("T_co", covariant=True)
class InterleavedSampler(Sampler[T_co]):
"""A meta-Sampler that interleaves multiple samplers.
The length of this sampler is based on the length of the shortest input sampler. All samplers will contribute the
same number of samples to the interleaved output.
Example:
sampler 1: ABCD
sampler 2: EFG
sampler 3: HIJKLM
interleaved sampler: AEHBFICGJ
"""
def __init__(self, samplers: list[Sampler[T_co] | typing.Iterable[T_co]]) -> None:
self._samplers = samplers
self._min_sampler_len = min([len(s) for s in self._samplers])
def __iter__(self) -> typing.Iterator[T_co]:
sampler_iters = [iter(s) for s in self._samplers]
while True:
samples = []
for sampler_iter in sampler_iters:
try:
samples.append(next(sampler_iter))
except StopIteration:
# The end of the shortest sampler has been reached.
return
yield from samples
def __len__(self) -> int:
return self._min_sampler_len * len(self._samplers)
================================================
FILE: src/invoke_training/_shared/data/samplers/offset_sampler.py
================================================
import typing
from torch.utils.data import Sampler
class OffsetSampler(Sampler[int]):
    """Wrap an index sampler and shift every index it yields by a fixed `offset`."""

    def __init__(self, sampler: Sampler[int], offset: int):
        self._sampler = sampler
        self._offset = offset

    def __iter__(self) -> typing.Iterator[int]:
        yield from (index + self._offset for index in self._sampler)

    def __len__(self) -> int:
        # Shifting indices does not change how many there are.
        wrapped = self._sampler
        return len(wrapped)
================================================
FILE: src/invoke_training/_shared/data/transforms/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/data/transforms/caption_prefix_transform.py
================================================
import typing
class CaptionPrefixTransform:
    """Prepend a fixed prefix string to the caption field of each example (modifies the example in place)."""

    def __init__(self, caption_field_name: str, prefix: str):
        self._caption_field_name = caption_field_name
        self._prefix = prefix

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        key = self._caption_field_name
        original_caption = data[key]
        data[key] = self._prefix + original_caption
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/concat_fields_transform.py
================================================
import typing
class ConcatFieldsTransform:
    """Join several string fields with a separator and store the result in a destination field (in place)."""

    def __init__(self, src_field_names: list[str], dst_field_name: str, separator: str = " "):
        self._src_field_names = src_field_names
        self._dst_field_name = dst_field_name
        self._separator = separator

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        parts = (data[name] for name in self._src_field_names)
        data[self._dst_field_name] = self._separator.join(parts)
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/constant_field_transform.py
================================================
import typing
class ConstantFieldTransform:
    """Set one field to the same constant value on every example (in place, overwriting any existing value)."""

    def __init__(self, field_name: str, field_value: typing.Any):
        self._field_name = field_name
        self._field_value = field_value

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        name, value = self._field_name, self._field_value
        data[name] = value
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/drop_field_transform.py
================================================
import typing
class DropFieldTransform:
    """Remove one field from each example if it is present (in place); missing fields are ignored."""

    def __init__(self, field_to_drop: str):
        self._field_to_drop = field_to_drop

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        data.pop(self._field_to_drop, None)
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/flux_image_transform.py
================================================
import typing
from torchvision import transforms
from torchvision.transforms.functional import crop
from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager, Resolution
from invoke_training._shared.data.utils.resize import resize_to_cover
class FluxImageTransform:
    """A transform that prepares and augments images for Flux.1-dev training.

    All fields in `image_field_names` are treated as spatially aligned (e.g. an image and a matching mask): they share
    one target resolution, one crop window and one flip decision, so they stay pixel-aligned after augmentation.
    """

    def __init__(
        self,
        image_field_names: list[str],
        fields_to_normalize_to_range_minus_one_to_one: list[str],
        resolution: int | None = 512,
        aspect_ratio_bucket_manager: AspectRatioBucketManager | None = None,
        random_flip: bool = True,
        center_crop: bool = True,
    ):
        """Initialize FluxImageTransform.

        Args:
            image_field_names (list[str]): The field names of the images to be transformed. The first field's size
                is used as the original size for all of them.
            fields_to_normalize_to_range_minus_one_to_one (list[str]): The subset of image fields whose tensors are
                rescaled from [0, 1] to [-1, 1]. Other fields are left in [0, 1].
            resolution (int | None): Side length of the square output images. Takes precedence over
                `aspect_ratio_bucket_manager` when both are set.
            aspect_ratio_bucket_manager (AspectRatioBucketManager): The AspectRatioBucketManager used to determine the
                target resolution for each image. Only used when `resolution` is None.
            random_flip (bool, optional): Whether to apply a random horizontal flip (probability 0.5) to the images.
            center_crop (bool, optional): If True, crop to the center of the image to achieve the target resolution. If
                False, crop at a random location.

        Raises:
            ValueError: If both `resolution` and `aspect_ratio_bucket_manager` are None.
        """
        if resolution is None and aspect_ratio_bucket_manager is None:
            raise ValueError("One of `resolution` or `aspect_ratio_bucket_manager` must be set.")

        self.image_field_names = image_field_names
        self.fields_to_normalize_to_range_minus_one_to_one = fields_to_normalize_to_range_minus_one_to_one
        self.resolution = resolution
        self.aspect_ratio_bucket_manager = aspect_ratio_bucket_manager
        self.random_flip = random_flip
        self.center_crop = center_crop

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:  # noqa: C901
        # TODO: Use a seed for repeatable results
        import random

        image_fields: dict = {field_name: data[field_name] for field_name in self.image_field_names}

        def first_image():
            return next(iter(image_fields.values()))

        # The first image determines the original size and the target resolution for all fields.
        # NOTE(review): assumes all image fields have the same size (SDImageTransform asserts this); crop offsets
        # computed from the first image are applied to every field.
        original_size_hw = (first_image().height, first_image().width)

        # Determine the target image resolution once for all fields.
        if self.resolution is not None:
            resolution_obj = Resolution(self.resolution, self.resolution)
        else:
            resolution_obj = self.aspect_ratio_bucket_manager.get_aspect_ratio_bucket(
                Resolution.parse(original_size_hw)
            )

        # Resize to cover the target resolution while preserving aspect ratio.
        image_fields = {name: resize_to_cover(image, resolution_obj) for name, image in image_fields.items()}

        # Choose ONE crop window (shared by all fields) and record its top-left position. Cropping explicitly with
        # these coordinates guarantees the recorded position matches the crop actually applied.
        if self.center_crop:
            top_left_y = max(0, (first_image().height - resolution_obj.height) // 2)
            top_left_x = max(0, (first_image().width - resolution_obj.width) // 2)
        else:
            top_left_y, top_left_x, _, _ = transforms.RandomCrop.get_params(first_image(), resolution_obj.to_tuple())
        image_fields = {
            name: crop(image, top_left_y, top_left_x, resolution_obj.height, resolution_obj.width)
            for name, image in image_fields.items()
        }

        # Make ONE flip decision for all fields and update the top left crop position accordingly (same formula as
        # SDImageTransform).
        if self.random_flip and random.random() < 0.5:
            top_left_x = original_size_hw[1] - first_image().width - top_left_x
            flip = transforms.RandomHorizontalFlip(p=1.0)
            image_fields = {name: flip(image) for name, image in image_fields.items()}

        # Convert to tensors; rescale the selected fields from [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize([0.5], [0.5])
        for field_name, image in image_fields.items():
            tensor = to_tensor(image)
            if field_name in self.fields_to_normalize_to_range_minus_one_to_one:
                tensor = normalize(tensor)
            data[field_name] = tensor

        # Add metadata fields expected by VAE caching
        data["original_size_hw"] = original_size_hw
        data["crop_top_left_yx"] = (top_left_y, top_left_x)
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/load_cache_transform.py
================================================
import typing
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
class LoadCacheTransform:
    """Copy selected entries of a TensorDiskCache record into each example, keyed by one of the example's fields."""

    def __init__(
        self, cache: TensorDiskCache, cache_key_field: str, cache_field_to_output_field: typing.Dict[str, str]
    ):
        """Initialize LoadCacheTransform.

        Args:
            cache (TensorDiskCache): The cache that records are loaded from.
            cache_key_field (str): The example field whose value is passed to `cache.load(...)` as the key.
            cache_field_to_output_field (typing.Dict[str, str]): Maps each field name in the cached record to the
                example field it should be written to.
        """
        self._cache = cache
        self._cache_key_field = cache_key_field
        self._cache_field_to_output_field = cache_field_to_output_field

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        cached_record = self._cache.load(data[self._cache_key_field])
        for cached_name, output_name in self._cache_field_to_output_field.items():
            data[output_name] = cached_record[cached_name]
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/sd_image_transform.py
================================================
import random
import typing
from torchvision import transforms
from torchvision.transforms.functional import crop
from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager, Resolution
from invoke_training._shared.data.utils.resize import resize_to_cover
class SDImageTransform:
    """A transform that prepares and augments images for Stable Diffusion training.

    All image fields are resized, cropped and (optionally) flipped identically, then converted to tensors. The
    original size and the crop's top-left corner are also written into the example (presumably consumed as SDXL
    size/crop conditioning downstream; confirm against callers).
    """

    def __init__(
        self,
        image_field_names: list[str],
        fields_to_normalize_to_range_minus_one_to_one: list[str],
        resolution: int | tuple[int, int] | Resolution | None,
        aspect_ratio_bucket_manager: AspectRatioBucketManager | None = None,
        center_crop: bool = True,
        random_flip: bool = False,
        orig_size_field_name: str = "original_size_hw",
        crop_field_name: str = "crop_top_left_yx",
    ):
        """Initialize SDImageTransform.

        Args:
            image_field_names (list[str]): The field names of the images to be transformed. All of these images must
                have the same size.
            fields_to_normalize_to_range_minus_one_to_one (list[str]): The subset of image fields whose tensors are
                rescaled from [0, 1] to [-1, 1]. Other fields are left in [0, 1].
            resolution (int | tuple[int, int] | Resolution | None): The fixed image resolution that will be produced
                (parsed with `Resolution.parse`). Exactly one of `resolution` and `aspect_ratio_bucket_manager` must
                be non-None.
            aspect_ratio_bucket_manager (AspectRatioBucketManager): The AspectRatioBucketManager used to determine the
                target resolution for each image from its original size. Exactly one of `resolution` and
                `aspect_ratio_bucket_manager` must be non-None.
            center_crop (bool, optional): If True, crop to the center of the image to achieve the target resolution. If
                False, crop at a random location.
            random_flip (bool, optional): If True, flip all image fields horizontally with probability 0.5.
            orig_size_field_name (str, optional): Output field that receives the original (height, width) of the
                images.
            crop_field_name (str, optional): Output field that receives the (top, left) crop position.

        Raises:
            ValueError: If both, or neither, of `resolution` and `aspect_ratio_bucket_manager` are set.
        """
        self._image_field_names = image_field_names
        self._fields_to_normalize_to_range_minus_one_to_one = fields_to_normalize_to_range_minus_one_to_one

        if resolution is not None and aspect_ratio_bucket_manager is not None:
            raise ValueError("Only one of `resolution` or `aspect_ratio_bucket_manager` should be set.")
        if resolution is None and aspect_ratio_bucket_manager is None:
            raise ValueError("One of `resolution` or `aspect_ratio_bucket_manager` must be set.")
        self._resolution = Resolution.parse(resolution) if resolution is not None else None
        self._aspect_ratio_bucket_manager = aspect_ratio_bucket_manager

        self._center_crop_enabled = center_crop
        self._random_flip_enabled = random_flip

        # p=1.0: the flip itself is deterministic; the 50% decision is made in __call__ so all fields flip together.
        self._flip_transform = transforms.RandomHorizontalFlip(p=1.0)
        self._to_tensor_transform = transforms.ToTensor()
        # Convert pixel values from range [0, 1.0] to range [-1.0, 1.0].
        # Normalize applies the following transform: out = (in - 0.5) / 0.5
        self._normalize_image_transform = transforms.Normalize([0.5], [0.5])

        self._orig_size_field_name = orig_size_field_name
        self._crop_field_name = crop_field_name

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:  # noqa: C901
        # This SDXL image pre-processing logic is adapted from:
        # https://github.com/huggingface/diffusers/blob/7b07f9812a58bfa96c06ed8ffe9e6b584286e2fd/examples/text_to_image/train_text_to_image_lora_sdxl.py#L850-L873

        image_fields: dict = {}
        for field_name in self._image_field_names:
            image_fields[field_name] = data[field_name]

        sizes = [image.size for image in image_fields.values()]
        # All images should have the same size.
        # NOTE(review): this `assert` is stripped under `python -O`; consider an explicit exception if mismatched
        # inputs are possible.
        assert all(size == sizes[0] for size in sizes)

        # Helper function to access the first image, which is sometimes used to infer the shape of all images.
        def get_first_image():
            return next(iter(image_fields.values()))

        original_size_hw = (get_first_image().height, get_first_image().width)

        # Determine the target image resolution.
        if self._resolution is not None:
            resolution = self._resolution
        else:
            resolution = self._aspect_ratio_bucket_manager.get_aspect_ratio_bucket(Resolution.parse(original_size_hw))

        # Resize to cover the target resolution while preserving aspect ratio.
        for field_name, image in image_fields.items():
            image_fields[field_name] = resize_to_cover(image, resolution)

        # Apply cropping, and record top left crop position.
        # One crop window is chosen from the first (resized) image and applied to every field.
        if self._center_crop_enabled:
            top_left_y = max(0, (get_first_image().height - resolution.height) // 2)
            top_left_x = max(0, (get_first_image().width - resolution.width) // 2)
        else:
            crop_transform = transforms.RandomCrop(resolution.to_tuple())
            # h, w are unused: they equal the requested crop size.
            top_left_y, top_left_x, h, w = crop_transform.get_params(get_first_image(), resolution.to_tuple())

        for field_name, image in image_fields.items():
            image_fields[field_name] = crop(image, top_left_y, top_left_x, resolution.height, resolution.width)

        # Apply random flip and update top left crop position accordingly.
        # TODO(ryand): Use a seed for repeatable results.
        # NOTE(review): the mirrored x uses the pre-resize original width minus the cropped width; confirm this mix of
        # pre-/post-resize coordinate frames is intended.
        if self._random_flip_enabled and random.random() < 0.5:
            top_left_x = original_size_hw[1] - get_first_image().width - top_left_x
            for field_name, image in image_fields.items():
                image_fields[field_name] = self._flip_transform(image)

        crop_top_left_yx = (top_left_y, top_left_x)

        # Convert to Tensors.
        for field_name, image in image_fields.items():
            image_fields[field_name] = self._to_tensor_transform(image)

        # Normalize to range [-1.0, 1.0].
        # HACK(ryand): We should find a better way to determine the normalization range of each image field.
        for field_name, image in image_fields.items():
            if field_name in self._fields_to_normalize_to_range_minus_one_to_one:
                image_fields[field_name] = self._normalize_image_transform(image)

        data[self._orig_size_field_name] = original_size_hw
        data[self._crop_field_name] = crop_top_left_yx
        for field_name, image in image_fields.items():
            data[field_name] = image
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/shuffle_caption_transform.py
================================================
import typing
import numpy as np
class ShuffleCaptionTransform:
    """Randomly reorder the delimiter-separated chunks of a caption field.

    Each chunk is stripped of surrounding whitespace before shuffling, and the shuffled chunks are re-joined with
    the delimiter followed by a single space.

    Example:
        - Original: "unreal engine, render of sci-fi helmet, dramatic lighting"
        - Shuffled: "render of sci-fi helmet, unreal engine, dramatic lighting"
    """

    def __init__(self, field_name: str, delimiter: str = ",", seed: int = 0):
        # Key of the example dict that holds the caption string.
        self._field_name = field_name
        self._delimiter = delimiter
        # Seeded generator so that a given seed yields a reproducible sequence of shuffles.
        self._rng = np.random.default_rng(seed)

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        """Shuffle the caption stored under the configured field in place and return `data`."""
        pieces = [piece.strip() for piece in data[self._field_name].split(self._delimiter)]
        self._rng.shuffle(pieces)
        data[self._field_name] = f"{self._delimiter} ".join(pieces)
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/template_caption_transform.py
================================================
import typing
import numpy as np
class TemplateCaptionTransform:
    """Build a caption for each example by filling a randomly chosen template with a placeholder string.

    Each template is formatted with `str.format(placeholder_str)`, so templates are expected to contain a `{}` slot.
    """

    def __init__(self, field_name: str, placeholder_str: str, caption_templates: list[str], seed: int = 0):
        # Key of the example dict that receives the generated caption.
        self._field_name = field_name
        self._placeholder_str = placeholder_str
        self._caption_templates = caption_templates
        # Seeded generator so that template selection is reproducible for a given seed.
        self._rng = np.random.default_rng(seed)

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        """Write a freshly generated caption into `data` under the configured field and return `data`."""
        template = self._rng.choice(self._caption_templates)
        caption = template.format(self._placeholder_str)
        # A template without a `{}` slot would silently drop the placeholder; fail loudly instead.
        assert self._placeholder_str in caption
        data[self._field_name] = caption
        return data
================================================
FILE: src/invoke_training/_shared/data/transforms/tensor_disk_cache.py
================================================
import os
import typing
import torch
class TensorDiskCache:
    """A write-once cache that persists dicts of `torch.Tensor`s as one `<key>.pt` file per key in a directory."""

    def __init__(self, cache_dir: str):
        super().__init__()
        self._cache_dir = cache_dir
        # Create the directory up front so that save()/load() can assume it exists.
        os.makedirs(self._cache_dir, exist_ok=True)

    def _get_path(self, key: int):
        """Return the file path that backs cache entry `key`.

        Args:
            key (int): The cache key.

        Returns:
            str: `<cache_dir>/<key>.pt`.
        """
        file_name = f"{key}.pt"
        return os.path.join(self._cache_dir, file_name)

    def save(self, key: int, data: typing.Dict[str, torch.Tensor]):
        """Store `data` under `key`.

        Args:
            key (int): The cache key.
            data (typing.Dict[str, torch.Tensor]): The data to save. Must be a dict.

        Raises:
            AssertionError: If `data` is not a dict, or an entry already exists for `key`.
        """
        # torch.save() accepts many object types; requiring a dict keeps the loading code uniform across callers.
        assert isinstance(data, dict)
        target_path = self._get_path(key)
        # Entries are write-once: an existing file for this key is treated as an error.
        assert not os.path.exists(target_path)
        torch.save(data, target_path)

    def load(self, key: int) -> typing.Dict[str, torch.Tensor]:
        """Read back the entry stored under `key`.

        Args:
            key (int): The cache key to load.

        Returns:
            typing.Dict[str, torch.Tensor]: The data that was saved for `key`.
        """
        source_path = self._get_path(key)
        return torch.load(source_path)
================================================
FILE: src/invoke_training/_shared/data/utils/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/data/utils/aspect_ratio_bucket_manager.py
================================================
from invoke_training._shared.data.utils.resolution import Resolution
class AspectRatioBucketManager:
    """Holds a fixed set of bucket resolutions and maps any resolution to the bucket with the closest aspect ratio."""

    def __init__(self, buckets: set[Resolution]):
        # The candidate bucket resolutions.
        self.buckets = buckets

    @classmethod
    def from_constraints(
        cls, target_resolution: int, start_dim: int, end_dim: int, divisible_by: int
    ) -> "AspectRatioBucketManager":
        """Create a manager whose buckets are generated by `build_aspect_ratio_buckets` from the given constraints."""
        buckets = cls.build_aspect_ratio_buckets(
            target_resolution=target_resolution,
            start_dim=start_dim,
            end_dim=end_dim,
            divisible_by=divisible_by,
        )
        return cls(buckets)

    @classmethod
    def build_aspect_ratio_buckets(
        cls, target_resolution: int, start_dim: int, end_dim: int, divisible_by: int
    ) -> set[Resolution]:
        """Prepare a set of aspect ratio bucket resolutions.

        For every height in `start_dim, start_dim + divisible_by, ..., end_dim`, the width is the largest multiple
        of `divisible_by` such that `height * width <= target_resolution * target_resolution`. Both
        `(height, width)` and its transpose `(width, height)` are added.

        Args:
            target_resolution (int): All resolutions in the returned set will aim to have close to
                (but <=) `target_resolution * target_resolution` pixels. Must be divisible by `divisible_by`.
            start_dim (int): The smallest height to generate a bucket for. Must be divisible by `divisible_by`.
            end_dim (int): The largest height (inclusive) to generate a bucket for. Must be divisible by
                `divisible_by` and >= `start_dim`.
            divisible_by (int): All dimensions in the returned set of resolutions will be divisible by `divisible_by`.

        Returns:
            set[Resolution]: The aspect ratio bucket resolutions.

        Raises:
            AssertionError: If the divisibility or ordering constraints above are violated.
        """
        # Validate target_resolution.
        assert target_resolution % divisible_by == 0
        # Validate start_dim, end_dim.
        assert start_dim <= end_dim
        assert start_dim % divisible_by == 0
        assert end_dim % divisible_by == 0

        target_size = target_resolution * target_resolution
        buckets = set()
        for height in range(start_dim, end_dim + 1, divisible_by):
            # Largest multiple of `divisible_by` such that height * width <= target_size.
            width = (target_size // height) // divisible_by * divisible_by
            if width <= 0:
                # Very large heights can round the width down to 0. Such a bucket has an undefined aspect ratio
                # (height / 0) and would make get_aspect_ratio_bucket() raise ZeroDivisionError, so skip it.
                continue
            buckets.add(Resolution(height, width))
            buckets.add(Resolution(width, height))
        return buckets

    def get_aspect_ratio_bucket(self, resolution: Resolution) -> Resolution:
        """Get the bucket with the closest aspect ratio (height / width) to `resolution`.

        Raises:
            ValueError: If `self.buckets` is empty (propagated from `min`).
        """
        # Note: If this is ever found to be a bottleneck, there is a clearly-more-efficient implementation using bisect.
        target_aspect_ratio = resolution.aspect_ratio()
        return min(self.buckets, key=lambda bucket: abs(bucket.aspect_ratio() - target_aspect_ratio))
================================================
FILE: src/invoke_training/_shared/data/utils/resize.py
================================================
import math
from PIL.Image import Image
from torchvision import transforms
from invoke_training._shared.data.utils.resolution import Resolution
def resize_to_cover(image: Image, size_to_cover: Resolution) -> Image:
    """Scale `image` to the smallest size that fully covers `size_to_cover`, keeping its aspect ratio.

    The returned image satisfies:
    - height >= size_to_cover.height and width >= size_to_cover.width
    - height == size_to_cover.height or width == size_to_cover.width
    - the aspect ratio of `image` is preserved (the non-matching side is rounded up to whole pixels)
    """
    height_ratio = size_to_cover.height / image.height
    width_ratio = size_to_cover.width / image.width
    # Scale by the larger ratio so both target dimensions are covered; round the other side up so it never falls short.
    if height_ratio > width_ratio:
        new_size = (size_to_cover.height, math.ceil(image.width * height_ratio))
    else:
        new_size = (math.ceil(image.height * width_ratio), size_to_cover.width)
    resizer = transforms.Resize(new_size, interpolation=transforms.InterpolationMode.BILINEAR)
    return resizer(image)
================================================
FILE: src/invoke_training/_shared/data/utils/resolution.py
================================================
from typing import Union
class Resolution:
    """An image resolution stored as (height, width)."""

    def __init__(self, height: int, width: int):
        self.height = height
        self.width = width

    @classmethod
    def parse(cls, resolution: Union[int, tuple[int, int], list[int], "Resolution"]):
        """Initialize a Resolution object from another type.

        Args:
            resolution: An int (interpreted as a square resolution), a `(height, width)` tuple or 2-element list,
                or another Resolution (copied).

        Raises:
            ValueError: If `resolution` has an unsupported type, or a tuple/list does not have exactly 2 elements.
        """
        if isinstance(resolution, int):
            # Assume square resolution.
            return cls(resolution, resolution)
        if isinstance(resolution, (tuple, list)):
            # Lists are accepted because config formats such as YAML/JSON have no tuple type.
            height, width = resolution
            return cls(height, width)
        if isinstance(resolution, cls):
            return cls(resolution.height, resolution.width)
        raise ValueError(f"Unsupported resolution type: '{type(resolution)}'.")

    def aspect_ratio(self) -> float:
        """Return height / width."""
        return self.height / self.width

    def to_tuple(self) -> tuple[int, int]:
        return (self.height, self.width)

    def __eq__(self, other: object) -> bool:
        # Returning NotImplemented for foreign types makes `res == None` evaluate to False rather than raising
        # AttributeError on `other.to_tuple()`.
        if not isinstance(other, Resolution):
            return NotImplemented
        return self.to_tuple() == other.to_tuple()

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, Resolution):
            return NotImplemented
        return self.to_tuple() < other.to_tuple()

    def __hash__(self):
        # Consistent with __eq__: equal resolutions hash equally.
        return hash(self.to_tuple())

    def __repr__(self) -> str:
        return f"{type(self).__name__}(height={self.height}, width={self.width})"
================================================
FILE: src/invoke_training/_shared/flux/encoding_utils.py
================================================
import logging
from typing import List, Optional, Tuple, Union
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
def get_clip_prompt_embeds(
    prompt: Union[str, List[str]],
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    device: torch.device,
    num_images_per_prompt: int = 1,
    tokenizer_max_length: int = 77,
    logger: logging.Logger | None = None,
) -> torch.FloatTensor:
    """Encodes the prompt using CLIP text encoder and returns pooled embeddings.

    Args:
        prompt: A single prompt or a list of prompts.
        tokenizer: The CLIP tokenizer.
        text_encoder: The CLIP text encoder; its `pooler_output` is returned.
        device: Device the token ids are moved to and the embeddings are returned on.
        num_images_per_prompt: Number of times each prompt's embedding is repeated.
        tokenizer_max_length: Prompts are padded/truncated to this many tokens.
        logger: If provided, a warning is logged when a prompt is truncated. If None, the truncation check
            (which costs a second tokenizer pass) is skipped.

    Returns:
        Pooled embeddings with `len(prompt) * num_images_per_prompt` rows, cast to `text_encoder.dtype`.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    # Process text input with the tokenizer.
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer_max_length,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    # Detect truncation only when there is somewhere to report it.
    if logger is not None:
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
            logger.warning(f"Warning: The following part of your input was truncated: {removed_text}")

    # Get prompt embeddings through the text encoder and use the pooled output of CLIPTextModel.
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
    prompt_embeds = prompt_embeds.pooler_output
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # Duplicate text embeddings for each generation per prompt.
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
    return prompt_embeds
def get_t5_prompt_embeds(
    prompt: Union[str, List[str]],
    tokenizer: T5TokenizerFast,
    text_encoder: T5EncoderModel,
    device: torch.device,
    num_images_per_prompt: int = 1,
    tokenizer_max_length: int = 512,
    logger: logging.Logger | None = None,
) -> torch.FloatTensor:
    """Encodes the prompt using T5 text encoder.

    Args:
        prompt: A single prompt or a list of prompts.
        tokenizer: The T5 tokenizer.
        text_encoder: The T5 encoder; the first element of its output is returned.
        device: Device the token ids are moved to and the embeddings are returned on.
        num_images_per_prompt: Number of times each prompt's embedding is repeated.
        tokenizer_max_length: Prompts are padded/truncated to this many tokens.
        logger: If provided, a warning is logged when a prompt is truncated. If None, the truncation check
            (which costs a second tokenizer pass) is skipped.

    Returns:
        Embeddings of shape `(len(prompt) * num_images_per_prompt, seq_len, -1)`, where `seq_len` is the encoder
        output's second dimension, cast to `text_encoder.dtype`.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    # Process text input with the tokenizer.
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer_max_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids

    # Detect truncation only when there is somewhere to report it.
    if logger is not None:
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
            logger.warning(f"Warning: The following part of your input was truncated: {removed_text}")

    # Get prompt embeddings through the text encoder.
    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0]
    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # Duplicate text embeddings for each generation per prompt.
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    return prompt_embeds
def handle_lora_scale(
    clip_text_encoder: CLIPTextModel,
    t5_text_encoder: T5EncoderModel,
    lora_scale: Optional[float] = None,
    use_peft_backend: bool = False,
):
    """Handles LoRA scale adjustments for text encoders.

    Scaling is applied only when `lora_scale` is not None and `use_peft_backend` is True. Encoders passed as None
    are skipped.

    Returns:
        bool: True if scaling was applied (pass this to `reset_lora_scale` as `lora_applied`), else False.
    """
    if lora_scale is None or not use_peft_backend:
        return False

    try:
        from peft.utils import scale_lora_layers
    except ImportError:
        # `scale_lora_layers` is provided by `diffusers.utils`; fall back to it if this peft version lacks it.
        from diffusers.utils import scale_lora_layers

    # Apply LoRA scaling to text encoders if they exist.
    for text_encoder in (clip_text_encoder, t5_text_encoder):
        if text_encoder is not None:
            scale_lora_layers(text_encoder, lora_scale)
    return True
def reset_lora_scale(
    clip_text_encoder: CLIPTextModel,
    t5_text_encoder: T5EncoderModel,
    lora_scale: Optional[float] = None,
    lora_applied: bool = False,
    use_peft_backend: bool = False,
):
    """Resets LoRA scale for text encoders if it was applied.

    Does nothing unless both `lora_applied` (as returned by `handle_lora_scale`) and `use_peft_backend` are True.
    Encoders passed as None are skipped.
    """
    if not (lora_applied and use_peft_backend):
        return

    try:
        from peft.utils import unscale_lora_layers
    except ImportError:
        # `unscale_lora_layers` is provided by `diffusers.utils`; fall back to it if this peft version lacks it.
        from diffusers.utils import unscale_lora_layers

    # Reset LoRA scaling.
    for text_encoder in (clip_text_encoder, t5_text_encoder):
        if text_encoder is not None:
            unscale_lora_layers(text_encoder, lora_scale)
# A lot of this code was adapted from:
# https://github.com/huggingface/diffusers/blob/ea81a4228d8ff16042c3ccaf61f0e588e60166cd/src/diffusers/pipelines/flux/pipeline_flux.py#L310-L387
def encode_prompt(
    prompt: Union[str, List[str]],
    prompt_2: Optional[Union[str, List[str]]],
    clip_tokenizer: CLIPTokenizer,
    t5_tokenizer: T5TokenizerFast,
    clip_text_encoder: CLIPTextModel,
    t5_text_encoder: T5EncoderModel,
    device: torch.device,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
    use_peft_backend: bool = False,
    clip_tokenizer_max_length: int = 77,
    t5_tokenizer_max_length: int = 512,
    logger: logging.Logger | None = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """
    Encodes the prompt using both CLIP and T5 text encoders.

    `prompt` is encoded by CLIP (pooled output) and `prompt_2` by T5; `prompt_2` falls back to `prompt` when it is
    falsy. If `prompt_embeds` is provided, no encoding is performed and the given `prompt_embeds` /
    `pooled_prompt_embeds` are returned unchanged.

    Args:
        logger: If provided, forwarded to the CLIP/T5 helpers so that prompt truncation is reported.
        (Remaining arguments are forwarded to `get_clip_prompt_embeds`, `get_t5_prompt_embeds`,
        `handle_lora_scale` and `reset_lora_scale`.)

    Returns:
        Tuple containing:
        - T5 text embeddings
        - CLIP pooled embeddings
        - Text IDs: zeros of shape `(prompt_embeds.shape[1], 3)`
    """
    # Apply LoRA scale if needed.
    lora_applied = handle_lora_scale(
        clip_text_encoder=clip_text_encoder,
        t5_text_encoder=t5_text_encoder,
        lora_scale=lora_scale,
        use_peft_backend=use_peft_backend,
    )

    try:
        # If no pre-generated embeddings, create them.
        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # Get CLIP pooled embeddings.
            pooled_prompt_embeds = get_clip_prompt_embeds(
                prompt=prompt,
                tokenizer=clip_tokenizer,
                text_encoder=clip_text_encoder,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                tokenizer_max_length=clip_tokenizer_max_length,
                logger=logger,
            )

            # Get T5 text embeddings.
            prompt_embeds = get_t5_prompt_embeds(
                prompt=prompt_2,
                tokenizer=t5_tokenizer,
                text_encoder=t5_text_encoder,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                tokenizer_max_length=t5_tokenizer_max_length,
                logger=logger,
            )
    finally:
        # Always undo the LoRA scaling, even if encoding raised, so the text encoders are not left scaled.
        reset_lora_scale(
            clip_text_encoder=clip_text_encoder,
            t5_text_encoder=t5_text_encoder,
            lora_scale=lora_scale,
            lora_applied=lora_applied,
            use_peft_backend=use_peft_backend,
        )

    # Create text_ids placeholder for model.
    dtype = clip_text_encoder.dtype if clip_text_encoder is not None else t5_text_encoder.dtype
    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
    return prompt_embeds, pooled_prompt_embeds, text_ids
================================================
FILE: src/invoke_training/_shared/flux/lora_checkpoint_utils.py
================================================
# ruff: noqa: N806
import os
from pathlib import Path

import peft
import torch
from diffusers import FluxTransformer2DModel
from transformers import CLIPTextModel, T5EncoderModel

from invoke_training._shared.checkpoints.lora_checkpoint_utils import (
    _convert_peft_state_dict_to_kohya_state_dict,
    load_multi_model_peft_checkpoint,
    save_multi_model_peft_checkpoint,
)
from invoke_training._shared.checkpoints.serialization import save_state_dict
# Module-name patterns that receive LoRA layers in the Flux transformer.
# NOTE(review): "attn.to_k"/"attn.to_q"/"attn.to_v" appear in both the double-block and single-block sections below.
# If PEFT matches target modules by name suffix (presumably), one listing already covers both block types and the
# repeats are redundant.
FLUX_TRANSFORMER_TARGET_MODULES = [
    # double blocks
    "attn.add_k_proj",
    "attn.add_q_proj",
    "attn.add_v_proj",
    "attn.to_add_out",
    "attn.to_k",
    "attn.to_q",
    "attn.to_v",
    "attn.to_out.0",
    "ff.net.0.proj",
    # NOTE(review): verify that modules ending in "ff.net.2.0" / "ff_context.net.2.0" exist. The Kohya converter in
    # this file (convert_double_transformer_block) maps "ff.net.2.weight" / "ff_context.net.2.weight", i.e. it
    # expects the module to be named "ff.net.2". If so, these two entries match nothing.
    "ff.net.2.0",
    "ff_context.net.0.proj",
    "ff_context.net.2.0",
    # single blocks
    "attn.to_k",
    "attn.to_q",
    "attn.to_v",
    "proj_mlp",
    # NOTE(review): "proj_out" may also match a top-level (non-block) "proj_out" module; confirm that is intended.
    "proj_out",
    # NOTE(review): confirm that the Flux transformer has modules named "proj_in".
    "proj_in",
]
# Target modules for the text encoder LoRA (attention projections and MLP layers).
TEXT_ENCODER_TARGET_MODULES = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"]
# Module lists copied from diffusers training script.
# These module lists will produce lighter, less expressive, LoRA models than the non-light versions.
# NOTE(review): the light transformer list targets no "add_*_proj" or "proj_mlp" modules, while the Kohya converters
# in this file pop those LoRA keys unconditionally, so a Kohya export of a light LoRA would presumably raise KeyError.
FLUX_TRANSFORMER_TARGET_MODULES_LIGHT = ["to_k", "to_q", "to_v", "to_out.0"]
FLUX_TEXT_ENCODER_TARGET_MODULES_LIGHT = ["q_proj", "k_proj", "v_proj", "out_proj"]
# Sub-directory / dict keys used for each model in a multi-model PEFT checkpoint.
FLUX_PEFT_TRANSFORMER_KEY = "transformer"
FLUX_PEFT_TEXT_ENCODER_1_KEY = "text_encoder_1"
FLUX_PEFT_TEXT_ENCODER_2_KEY = "text_encoder_2"
# Key prefixes used for each model in a Kohya-format state dict.
FLUX_KOHYA_TRANSFORMER_KEY = "lora_unet"
FLUX_KOHYA_TEXT_ENCODER_1_KEY = "lora_clip"
FLUX_KOHYA_TEXT_ENCODER_2_KEY = "lora_t5"
# Maps PEFT checkpoint sub-directory names to the corresponding Kohya key prefixes.
FLUX_PEFT_TO_KOHYA_KEYS = {
    FLUX_PEFT_TRANSFORMER_KEY: FLUX_KOHYA_TRANSFORMER_KEY,
    FLUX_PEFT_TEXT_ENCODER_1_KEY: FLUX_KOHYA_TEXT_ENCODER_1_KEY,
    FLUX_PEFT_TEXT_ENCODER_2_KEY: FLUX_KOHYA_TEXT_ENCODER_2_KEY,
}
def save_flux_peft_checkpoint(
    checkpoint_dir: Path | str,
    transformer: peft.PeftModel | None,
    text_encoder_1: peft.PeftModel | None,
    text_encoder_2: peft.PeftModel | None,
):
    """Save the given Flux PEFT models to `checkpoint_dir` via `save_multi_model_peft_checkpoint`.

    Each model is stored under its FLUX_PEFT_*_KEY name. Models passed as None are omitted.
    """
    candidates = {
        FLUX_PEFT_TRANSFORMER_KEY: transformer,
        FLUX_PEFT_TEXT_ENCODER_1_KEY: text_encoder_1,
        FLUX_PEFT_TEXT_ENCODER_2_KEY: text_encoder_2,
    }
    present_models = {key: model for key, model in candidates.items() if model is not None}
    save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=present_models)
def load_flux_peft_checkpoint(
    checkpoint_dir: Path | str,
    transformer: FluxTransformer2DModel,
    text_encoder_1: CLIPTextModel,
    text_encoder_2: T5EncoderModel,
    is_trainable: bool = False,
):
    """Load Flux PEFT adapters from `checkpoint_dir` onto the given base models.

    Missing model sub-directories in the checkpoint do not raise (`raise_if_subdir_missing=False`).

    Args:
        checkpoint_dir: Directory previously written by `save_flux_peft_checkpoint`.
        transformer: Base Flux transformer.
        text_encoder_1: Base CLIP text encoder.
        text_encoder_2: Base T5 text encoder.
        is_trainable: Forwarded to `load_multi_model_peft_checkpoint`.

    Returns:
        A `(transformer, text_encoder_1, text_encoder_2)` tuple as returned by `load_multi_model_peft_checkpoint`.
    """
    models = load_multi_model_peft_checkpoint(
        checkpoint_dir=checkpoint_dir,
        models={
            FLUX_PEFT_TRANSFORMER_KEY: transformer,
            FLUX_PEFT_TEXT_ENCODER_1_KEY: text_encoder_1,
            FLUX_PEFT_TEXT_ENCODER_2_KEY: text_encoder_2,
        },
        is_trainable=is_trainable,
        raise_if_subdir_missing=False,
    )
    return (
        models[FLUX_PEFT_TRANSFORMER_KEY],
        models[FLUX_PEFT_TEXT_ENCODER_1_KEY],
        models[FLUX_PEFT_TEXT_ENCODER_2_KEY],
    )
def save_flux_kohya_checkpoint(
    checkpoint_path: Path,
    transformer: peft.PeftModel | None,
    text_encoder_1: peft.PeftModel | None,
    text_encoder_2: peft.PeftModel | None,
):
    """Convert the given Flux PEFT models to one Kohya-format state dict and save it to `checkpoint_path`.

    Models passed as None are skipped. The parent directory of `checkpoint_path` is created if needed.

    NOTE(review): `text_encoder_2` is accepted but not exported; only the transformer ("lora_unet") and
    text_encoder_1 ("lora_clip") are converted. Confirm whether that is intentional.
    """
    candidates = [
        (FLUX_KOHYA_TRANSFORMER_KEY, transformer),
        (FLUX_KOHYA_TEXT_ENCODER_1_KEY, text_encoder_1),
    ]
    present = [(prefix, model) for prefix, model in candidates if model is not None]
    kohya_state_dict = _convert_peft_models_to_kohya_state_dict(
        kohya_prefixes=[prefix for prefix, _ in present],
        models=[model for _, model in present],
    )
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    save_state_dict(kohya_state_dict, checkpoint_path)
def convert_flux_peft_checkpoint_to_kohya_state_dict(
    in_checkpoint_dir: Path,
    out_checkpoint_file: Path,
    dtype: torch.dtype = torch.float32,
) -> dict[str, torch.Tensor]:
    """Convert Flux PEFT models to a Kohya-format LoRA state dict and save it.

    Each immediate subdirectory of `in_checkpoint_dir` is treated as one PEFT model and must be named after a key of
    `FLUX_PEFT_TO_KOHYA_KEYS`.

    NOTE(review): unlike `_convert_peft_models_to_kohya_state_dict`, transformer weights are not passed through
    `convert_diffusers_to_flux_transformer_checkpoint` here, so the two export paths may produce different key
    layouts. Verify which one consumers expect.

    Args:
        in_checkpoint_dir: Directory containing one sub-directory per PEFT model. A `str` is also accepted.
        out_checkpoint_file: Destination file; its parent directory is created if needed. A `str` is also accepted.
        dtype: Dtype the converted weights are cast to.

    Returns:
        The Kohya-format state dict that was written to `out_checkpoint_file`.

    Raises:
        ValueError: If there are no sub-directories, or a sub-directory name is not a recognized model key.
    """
    in_checkpoint_dir = Path(in_checkpoint_dir)
    out_checkpoint_file = Path(out_checkpoint_file)

    # Each immediate sub-directory is assumed to be a PEFT model. Sort for a deterministic output key order.
    peft_model_dirs = sorted(d for d in in_checkpoint_dir.iterdir() if d.is_dir())
    if len(peft_model_dirs) == 0:
        raise ValueError(f"No checkpoint files found in directory '{in_checkpoint_dir}'.")

    kohya_state_dict = {}
    for peft_model_dir in peft_model_dirs:
        kohya_prefix = FLUX_PEFT_TO_KOHYA_KEYS.get(peft_model_dir.name)
        if kohya_prefix is None:
            raise ValueError(f"Unrecognized checkpoint directory: '{peft_model_dir}'.")

        # Note: This logic to load the LoraConfig and weights directly is based on how it is done here:
        # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/src/peft/peft_model.py#L672-L689
        # This may need to be updated in the future to support other adapter types (LoKr, LoHa, etc.).
        # Also, I could see this interface breaking in the future.
        lora_config = peft.LoraConfig.from_pretrained(peft_model_dir)
        lora_weights = peft.utils.load_peft_weights(peft_model_dir, device="cpu")
        kohya_state_dict.update(
            _convert_peft_state_dict_to_kohya_state_dict(
                lora_config=lora_config, peft_state_dict=lora_weights, prefix=kohya_prefix, dtype=dtype
            )
        )

    # Match save_flux_kohya_checkpoint, which also creates the destination directory.
    out_checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
    save_state_dict(kohya_state_dict, out_checkpoint_file)
    return kohya_state_dict
def _convert_peft_models_to_kohya_state_dict(
    kohya_prefixes: list[str], models: list[peft.PeftModel]
) -> dict[str, torch.Tensor]:
    """Merge several PEFT LoRA models into one Kohya-format state dict.

    `kohya_prefixes[i]` is the key prefix used for `models[i]`; the lists must have equal length. Only the "default"
    adapter of each model is exported, and the weights are cast to float32. Transformer weights are first renamed
    from diffusers naming via `convert_diffusers_to_flux_transformer_checkpoint`.
    """
    adapter_name = "default"
    merged: dict[str, torch.Tensor] = {}
    for prefix, model in zip(kohya_prefixes, models, strict=True):
        config = model.peft_config[adapter_name]
        assert isinstance(config, peft.LoraConfig)
        peft_state = peft.get_peft_model_state_dict(model, adapter_name=adapter_name)
        if prefix == FLUX_KOHYA_TRANSFORMER_KEY:
            peft_state = convert_diffusers_to_flux_transformer_checkpoint(peft_state)
        converted = _convert_peft_state_dict_to_kohya_state_dict(
            lora_config=config,
            peft_state_dict=peft_state,
            prefix=prefix,
            dtype=torch.float32,
        )
        merged.update(converted)
    return merged
def find_matching_key_prefix(state_dict, key_pattern):
    """
    Return the module prefix of `key_pattern` if any key of `state_dict` contains it.

    The prefix is `key_pattern` cut before its first ".lora_A", then before its first ".lora_B", then before its
    first ".weight". Matching is a plain substring test against every key.

    Args:
        state_dict: The state dictionary to search in
        key_pattern: The pattern to look for in keys
    Returns:
        The matching prefix if found, False otherwise
    """
    base_prefix = key_pattern
    for marker in (".lora_A", ".lora_B", ".weight"):
        base_prefix = base_prefix.split(marker)[0]
    if any(base_prefix in key for key in state_dict):
        return base_prefix
    return False
def convert_layer_weights(target_dict, source_dict, source_pattern, target_pattern):
    """
    Move every weight of the module named by `source_pattern` from `source_dict` into `target_dict`, renamed.

    The module prefix is derived from `source_pattern` by `find_matching_key_prefix`. Every key of `source_dict` that
    contains that prefix is popped and re-inserted into `target_dict` with the prefix replaced by `target_pattern`
    (with every ".weight" removed), so suffixes such as ".lora_A.weight" are preserved. Both dicts are mutated in
    place. If nothing matches, both dicts are returned unchanged.

    Args:
        target_dict: Dictionary to store converted weights
        source_dict: Source dictionary containing weights
        source_pattern: Original key pattern to search for
        target_pattern: New key pattern to use
    Returns:
        Tuple of (updated target_dict, updated source_dict)
    """
    module_prefix = find_matching_key_prefix(source_dict, source_pattern)
    if not module_prefix:
        return target_dict, source_dict
    renamed_prefix = target_pattern.replace(".weight", "")
    # Materialize the matching keys first because the loop pops from source_dict.
    matching_keys = [key for key in source_dict.keys() if module_prefix in key]
    for key in matching_keys:
        target_dict[key.replace(module_prefix, renamed_prefix)] = source_dict.pop(key)
    return target_dict, source_dict
def convert_double_transformer_block(target_dict, source_dict, prefix="", block_idx=0):
    """
    Convert weights for a double transformer block.

    Renames the weights of diffusers block `transformer_blocks.{block_idx}` to the `double_blocks.{block_idx}` naming.
    Keys are popped from `source_dict` and inserted into `target_dict`; both dicts are mutated in place. The separate
    image-stream (to_q/to_k/to_v) and text-stream (add_q_proj/add_k_proj/add_v_proj) LoRA weights are concatenated
    along dim 0 into fused `img_attn.qkv` / `txt_attn.qkv` entries.

    Args:
        target_dict: Dictionary to store converted weights
        source_dict: Source dictionary containing weights
        prefix: Prefix for the keys in the state dictionary
        block_idx: Block index
    Returns:
        Tuple of (updated target_dict, updated source_dict)
    Raises:
        KeyError: If any of the twelve Q/K/V lora_A/lora_B entries for this block is missing from `source_dict`.
    """
    block_prefix = f"transformer_blocks.{block_idx}."
    # Convert norms
    target_dict, source_dict = convert_layer_weights(
        target_dict,
        source_dict,
        f"{prefix}{block_prefix}norm1.linear.weight",
        f"double_blocks.{block_idx}.img_mod.lin.weight",
    )
    target_dict, source_dict = convert_layer_weights(
        target_dict,
        source_dict,
        f"{prefix}{block_prefix}norm1_context.linear.weight",
        f"double_blocks.{block_idx}.txt_mod.lin.weight",
    )
    # Convert attention weights by concatenating Q, K, V
    # NOTE(review): assuming PEFT's Linear LoRA layout (lora_A: (rank, in), lora_B: (out, rank)), concatenating both A
    # and B along dim 0 gives A: (3*rank, in) and B: (3*out, rank), so B @ A is not shape-compatible. A fused LoRA
    # normally needs either a shared A or a block-diagonal B. Verify against the consumer of this format.
    try:
        # Sample attention weights
        sample_q_A = source_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_A.weight")
        sample_q_B = source_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_B.weight")
        sample_k_A = source_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_A.weight")
        sample_k_B = source_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_B.weight")
        sample_v_A = source_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_A.weight")
        sample_v_B = source_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_B.weight")
        # Context attention weights
        context_q_A = source_dict.pop(f"{prefix}{block_prefix}attn.add_q_proj.lora_A.weight")
        context_q_B = source_dict.pop(f"{prefix}{block_prefix}attn.add_q_proj.lora_B.weight")
        context_k_A = source_dict.pop(f"{prefix}{block_prefix}attn.add_k_proj.lora_A.weight")
        context_k_B = source_dict.pop(f"{prefix}{block_prefix}attn.add_k_proj.lora_B.weight")
        context_v_A = source_dict.pop(f"{prefix}{block_prefix}attn.add_v_proj.lora_A.weight")
        context_v_B = source_dict.pop(f"{prefix}{block_prefix}attn.add_v_proj.lora_B.weight")
        # Concatenate Q, K, V for image and text
        target_dict[f"double_blocks.{block_idx}.img_attn.qkv.lora_A.weight"] = torch.cat(
            [sample_q_A, sample_k_A, sample_v_A], dim=0
        )
        target_dict[f"double_blocks.{block_idx}.img_attn.qkv.lora_B.weight"] = torch.cat(
            [sample_q_B, sample_k_B, sample_v_B], dim=0
        )
        target_dict[f"double_blocks.{block_idx}.txt_attn.qkv.lora_A.weight"] = torch.cat(
            [context_q_A, context_k_A, context_v_A], dim=0
        )
        target_dict[f"double_blocks.{block_idx}.txt_attn.qkv.lora_B.weight"] = torch.cat(
            [context_q_B, context_k_B, context_v_B], dim=0
        )
    except KeyError as e:
        # Report which block failed, then propagate. Keys popped before the failure are not restored.
        print(f"Error processing attention weights for block {block_idx}: {e}")
        raise
    # Convert QK norms
    # NOTE(review): convert_layer_weights only strips ".weight" from the target pattern, so a source key ending in
    # "norm_q.weight" would be renamed to "...query_norm.scale.weight" rather than "...query_norm.scale". These norm
    # weights may never be LoRA targets, in which case this path is unused. Verify before relying on it.
    norm_keys = [
        (f"{prefix}{block_prefix}attn.norm_q.weight", f"double_blocks.{block_idx}.img_attn.norm.query_norm.scale"),
        (f"{prefix}{block_prefix}attn.norm_k.weight", f"double_blocks.{block_idx}.img_attn.norm.key_norm.scale"),
        (
            f"{prefix}{block_prefix}attn.norm_added_q.weight",
            f"double_blocks.{block_idx}.txt_attn.norm.query_norm.scale",
        ),
        (f"{prefix}{block_prefix}attn.norm_added_k.weight", f"double_blocks.{block_idx}.txt_attn.norm.key_norm.scale"),
    ]
    for src_key, target_key in norm_keys:
        target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key)
    # Convert MLP weights
    mlp_keys = [
        (f"{prefix}{block_prefix}ff.net.0.proj.weight", f"double_blocks.{block_idx}.img_mlp.0.weight"),
        (f"{prefix}{block_prefix}ff.net.2.weight", f"double_blocks.{block_idx}.img_mlp.2.weight"),
        (f"{prefix}{block_prefix}ff_context.net.0.proj.weight", f"double_blocks.{block_idx}.txt_mlp.0.weight"),
        (f"{prefix}{block_prefix}ff_context.net.2.weight", f"double_blocks.{block_idx}.txt_mlp.2.weight"),
    ]
    for src_key, target_key in mlp_keys:
        target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key)
    # Convert output projections
    output_keys = [
        (f"{prefix}{block_prefix}attn.to_out.0.weight", f"double_blocks.{block_idx}.img_attn.proj.weight"),
        (f"{prefix}{block_prefix}attn.to_add_out.weight", f"double_blocks.{block_idx}.txt_attn.proj.weight"),
    ]
    for src_key, target_key in output_keys:
        target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key)
    return target_dict, source_dict
def convert_single_transformer_block(target_dict, source_dict, prefix, block_idx):
    """
    Convert weights for a single transformer block.

    Renames the weights of diffusers block `single_transformer_blocks.{block_idx}` to the `single_blocks.{block_idx}`
    naming. Keys are popped from `source_dict` and inserted into `target_dict`; both dicts are mutated in place. The
    separate to_q/to_k/to_v/proj_mlp LoRA weights are concatenated along dim 0 into one `linear1` entry, and
    `proj_out` is renamed to `linear2`.

    Args:
        target_dict: Dictionary to store converted weights
        source_dict: Source dictionary containing weights
        prefix: Prefix for the keys in the state dictionary
        block_idx: Block index
    Returns:
        Tuple of (updated target_dict, updated source_dict)
    Raises:
        KeyError: If any of the eight to_q/to_k/to_v/proj_mlp lora_A/lora_B entries for this block is missing.
    """
    block_prefix = f"single_transformer_blocks.{block_idx}."
    # Convert norm
    target_dict, source_dict = convert_layer_weights(
        target_dict,
        source_dict,
        f"{prefix}{block_prefix}norm.linear.weight",
        f"single_blocks.{block_idx}.modulation.lin.weight",
    )
    try:
        # Convert Q, K, V, MLP by concatenating
        # NOTE(review): as in the double-block converter, concatenating both lora_A and lora_B along dim 0 yields
        # factors whose product is not shape-compatible under PEFT's (rank, in) / (out, rank) layout, unless the
        # consumer expects this specific layout. Verify.
        q_A = source_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_A.weight")
        q_B = source_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_B.weight")
        k_A = source_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_A.weight")
        k_B = source_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_B.weight")
        v_A = source_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_A.weight")
        v_B = source_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_B.weight")
        mlp_A = source_dict.pop(f"{prefix}{block_prefix}proj_mlp.lora_A.weight")
        mlp_B = source_dict.pop(f"{prefix}{block_prefix}proj_mlp.lora_B.weight")
        target_dict[f"single_blocks.{block_idx}.linear1.lora_A.weight"] = torch.cat([q_A, k_A, v_A, mlp_A], dim=0)
        target_dict[f"single_blocks.{block_idx}.linear1.lora_B.weight"] = torch.cat([q_B, k_B, v_B, mlp_B], dim=0)
    except KeyError as e:
        # Report which block failed, then propagate. Keys popped before the failure are not restored.
        print(f"Error processing attention weights for single block {block_idx}: {e}")
        raise
    # Convert output projection
    target_dict, source_dict = convert_layer_weights(
        target_dict,
        source_dict,
        f"{prefix}{block_prefix}proj_out.weight",
        f"single_blocks.{block_idx}.linear2.weight",
    )
    return target_dict, source_dict
def convert_embedding_layers(target_dict, source_dict, prefix, has_guidance=True):
    """
    Convert time, text, guidance, and context embedding layers.

    Every conversion is a rename done by `convert_layer_weights`; entries that are absent from `source_dict` are
    skipped. Both dicts are mutated in place.

    Args:
        target_dict: Dictionary to store converted weights
        source_dict: Source dictionary containing weights
        prefix: Prefix for the keys in the state dictionary
        has_guidance: Whether the model has guidance embedding
    Returns:
        Tuple of (updated target_dict, updated source_dict)
    """
    # Convert time embedding. Both linear layers are mapped, mirroring the text and guidance embedders below.
    # Previously only linear_1 was mapped, so linear_2 weights were left behind and reported as unexpected keys.
    time_embed_keys = [
        (f"{prefix}time_text_embed.timestep_embedder.linear_1.weight", "time_in.in_layer.weight"),
        (f"{prefix}time_text_embed.timestep_embedder.linear_2.weight", "time_in.out_layer.weight"),
    ]
    for src_key, target_key in time_embed_keys:
        target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key)

    # Convert text embedding.
    text_embed_keys = [
        (f"{prefix}time_text_embed.text_embedder.linear_1.weight", "vector_in.in_layer.weight"),
        (f"{prefix}time_text_embed.text_embedder.linear_2.weight", "vector_in.out_layer.weight"),
    ]
    for src_key, target_key in text_embed_keys:
        target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key)

    # Convert guidance embedding if needed.
    if has_guidance:
        guidance_keys = [
            (f"{prefix}time_text_embed.guidance_embedder.linear_1.weight", "guidance_in.in_layer.weight"),
            (f"{prefix}time_text_embed.guidance_embedder.linear_2.weight", "guidance_in.out_layer.weight"),
        ]
        for src_key, target_key in guidance_keys:
            target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key)

    # Convert context and image embedders.
    embed_keys = [
        (f"{prefix}context_embedder.weight", "txt_in.weight"),
        (f"{prefix}x_embedder.weight", "img_in.weight"),
    ]
    for src_key, target_key in embed_keys:
        target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key)
    return target_dict, source_dict
def convert_output_layers(target_dict, source_dict, prefix):
    """
    Convert final output layers.

    Renames `proj_out` (weight and bias) to `final_layer.linear` and `norm_out.linear` to
    `final_layer.adaLN_modulation.1`, via `convert_layer_weights`. Both dicts are mutated in place.

    NOTE(review): this is a pure rename. If formats differ in the ordering of the shift/scale halves of the norm_out
    projection, those rows would also need swapping. Verify if `norm_out` is ever LoRA-targeted.

    Args:
        target_dict: Dictionary to store converted weights
        source_dict: Source dictionary containing weights
        prefix: Prefix for the keys in the state dictionary
    Returns:
        Tuple of (updated target_dict, updated source_dict)
    """
    renames = (
        ("proj_out.weight", "final_layer.linear.weight"),
        ("proj_out.bias", "final_layer.linear.bias"),
        ("norm_out.linear.weight", "final_layer.adaLN_modulation.1.weight"),
    )
    for source_name, target_name in renames:
        target_dict, source_dict = convert_layer_weights(
            target_dict, source_dict, f"{prefix}{source_name}", target_name
        )
    return target_dict, source_dict
def convert_diffusers_to_flux_transformer_checkpoint(
    diffusers_state_dict,
    num_layers=19,
    num_single_layers=38,
    has_guidance=True,
    old_prefix="base_model.model.",
    new_prefix=FLUX_KOHYA_TRANSFORMER_KEY,
):
    """
    Convert a diffusers state dictionary to flux transformer checkpoint format.

    The input dictionary is not modified; conversion works on a shallow copy. Any keys that no converter consumes
    are printed as "Unexpected keys" and are NOT included in the result.

    Args:
        diffusers_state_dict: Source diffusers state dictionary
        num_layers: Number of double transformer layers
        num_single_layers: Number of single transformer layers
        has_guidance: Whether the model has guidance embedding
        old_prefix: Prefix of the keys in the source dictionary
        new_prefix: Prefix (joined with ".") prepended to every key of the returned dictionary
    Returns:
        A new state dictionary in flux transformer format
    """
    # Work on a copy: the converters below pop keys, which would otherwise empty the caller's dictionary.
    diffusers_state_dict = dict(diffusers_state_dict)
    flux_state_dict = {}

    # Convert embedding layers.
    flux_state_dict, diffusers_state_dict = convert_embedding_layers(
        flux_state_dict, diffusers_state_dict, old_prefix, has_guidance
    )
    # Convert double transformer blocks.
    for i in range(num_layers):
        flux_state_dict, diffusers_state_dict = convert_double_transformer_block(
            flux_state_dict, diffusers_state_dict, old_prefix, i
        )
    # Convert single transformer blocks.
    for i in range(num_single_layers):
        flux_state_dict, diffusers_state_dict = convert_single_transformer_block(
            flux_state_dict, diffusers_state_dict, old_prefix, i
        )
    # Convert output layers.
    flux_state_dict, diffusers_state_dict = convert_output_layers(flux_state_dict, diffusers_state_dict, old_prefix)

    # Check for leftover keys; these are dropped from the result.
    if diffusers_state_dict:
        print(f"Unexpected keys: {list(diffusers_state_dict.keys())}")

    # Prepend the new prefix to every converted key (insertion order is preserved).
    return {f"{new_prefix}.{key}": value for key, value in flux_state_dict.items()}
================================================
FILE: src/invoke_training/_shared/flux/model_loading_utils.py
================================================
import logging
from enum import Enum
import torch
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
class PipelineVersionEnum(Enum):
    """Pipeline versions understood by the Flux loaders in this module (currently only `FLUX`)."""

    FLUX = "FLUX"
def load_pipeline(
    logger: logging.Logger,
    model_name_or_path: str = "black-forest-labs/FLUX.1-dev",
    pipeline_version: PipelineVersionEnum = PipelineVersionEnum.FLUX,
    transformer_path: str | None = None,
    text_encoder_1_path: str | None = None,
    text_encoder_2_path: str | None = None,
    torch_dtype: torch.dtype | None = None,
) -> FluxPipeline:
    """Load a Flux pipeline, optionally overriding individual components.

    Each override is loaded with the component class's `from_pretrained(...)`. The path must therefore be something
    `from_pretrained` accepts, such as a Hugging Face repo name or a local model directory. A bare single-file
    checkpoint is not what `from_pretrained` expects.

    Args:
        logger: Logger instance.
        model_name_or_path: Base model path or repository. Components that are not overridden are loaded from here.
        pipeline_version: Pipeline version (currently only FLUX supported).
        transformer_path: Optional source for a custom `FluxTransformer2DModel`.
        text_encoder_1_path: Optional source for a custom CLIP text encoder (`CLIPTextModel`).
        text_encoder_2_path: Optional source for a custom T5 text encoder (`T5EncoderModel`).
        torch_dtype: Desired dtype for the models (forwarded to every `from_pretrained` call).

    Returns:
        FluxPipeline: Configured pipeline with custom components if specified.

    Raises:
        ValueError: If `pipeline_version` is not `PipelineVersionEnum.FLUX`.
    """
    if pipeline_version != PipelineVersionEnum.FLUX:
        raise ValueError(f"Invalid pipeline version: {pipeline_version}")

    # kwargs for FluxPipeline.from_pretrained(). Components are added only when a custom path is given; otherwise the
    # pipeline loads the default component from `model_name_or_path`.
    kwargs = {"torch_dtype": torch_dtype}

    # Log before each load so that a failing load is preceded by context about which source was being read.
    if transformer_path is not None:
        logger.info(f"Loading custom transformer from {transformer_path}")
        kwargs["transformer"] = FluxTransformer2DModel.from_pretrained(transformer_path, torch_dtype=torch_dtype)

    if text_encoder_1_path is not None:
        logger.info(f"Loading custom CLIP text encoder from {text_encoder_1_path}")
        kwargs["text_encoder"] = CLIPTextModel.from_pretrained(text_encoder_1_path, torch_dtype=torch_dtype)

    if text_encoder_2_path is not None:
        logger.info(f"Loading custom T5 text encoder from {text_encoder_2_path}")
        kwargs["text_encoder_2"] = T5EncoderModel.from_pretrained(text_encoder_2_path, torch_dtype=torch_dtype)

    # Load the pipeline with any custom components.
    return FluxPipeline.from_pretrained(model_name_or_path, **kwargs)
def load_models_flux(
    logger: logging.Logger,
    model_name_or_path: str = "black-forest-labs/FLUX.1-dev",
    dtype: torch.dtype | None = None,
    transformer_path: str | None = None,
    text_encoder_1_path: str | None = None,
    text_encoder_2_path: str | None = None,
) -> tuple[
    CLIPTokenizer,
    T5Tokenizer,
    FlowMatchEulerDiscreteScheduler,
    CLIPTextModel,
    T5EncoderModel,
    AutoencoderKL,
    FluxTransformer2DModel,
]:
    """Load all Flux models required for training.

    The text encoders, VAE and transformer are frozen (`requires_grad_(False)`), optionally cast to `dtype`, and put
    in eval mode. No model is moved to a device here.

    Args:
        logger: Logger instance.
        model_name_or_path: Base model path or repository.
        dtype: Desired dtype for the models.
        transformer_path: Optional source for a custom transformer (see `load_pipeline`).
        text_encoder_1_path: Optional source for a custom CLIP text encoder (see `load_pipeline`).
        text_encoder_2_path: Optional source for a custom T5 text encoder (see `load_pipeline`).

    Returns:
        (tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, transformer)

    Raises:
        ValueError: If text_encoder_1, text_encoder_2, the transformer or the VAE is missing from the loaded pipeline.
    """
    pipeline: FluxPipeline = load_pipeline(
        logger=logger,
        model_name_or_path=model_name_or_path,
        pipeline_version=PipelineVersionEnum.FLUX,
        transformer_path=transformer_path,
        text_encoder_1_path=text_encoder_1_path,
        text_encoder_2_path=text_encoder_2_path,
        torch_dtype=dtype,
    )

    # Tokenizers and text encoders.
    tokenizer_1: CLIPTokenizer = pipeline.tokenizer
    text_encoder_1: CLIPTextModel = pipeline.text_encoder
    tokenizer_2: T5Tokenizer = pipeline.tokenizer_2
    text_encoder_2: T5EncoderModel = pipeline.text_encoder_2

    # Transformer and scheduler.
    transformer: FluxTransformer2DModel = pipeline.transformer
    noise_scheduler: FlowMatchEulerDiscreteScheduler = pipeline.scheduler

    # Decoder.
    vae: AutoencoderKL = pipeline.vae

    # Log component status.
    logger.info(
        f"Pipeline components loaded: tokenizer_1={tokenizer_1 is not None}, "
        f"text_encoder_1={text_encoder_1 is not None}, "
        f"tokenizer_2={tokenizer_2 is not None}, "
        f"text_encoder_2={text_encoder_2 is not None}, "
        f"transformer={transformer is not None}, "
        f"vae={vae is not None}"
    )

    # Fail early with an actionable message if a required component is missing.
    for name, component in (
        ("text_encoder_1", text_encoder_1),
        ("text_encoder_2", text_encoder_2),
        ("transformer", transformer),
        ("vae", vae),
    ):
        if component is None:
            raise ValueError(
                f"{name} failed to load. "
                "Check if you have access to the model repository and are properly authenticated."
            )

    # Disable gradient calculation for model weights to save memory.
    text_encoder_1.requires_grad_(False)
    text_encoder_2.requires_grad_(False)
    vae.requires_grad_(False)
    transformer.requires_grad_(False)

    if dtype is not None:
        text_encoder_1 = text_encoder_1.to(dtype=dtype)
        text_encoder_2 = text_encoder_2.to(dtype=dtype)
        vae = vae.to(dtype=dtype)
        transformer = transformer.to(dtype=dtype)

    # Put models in 'eval' mode.
    text_encoder_1.eval()
    text_encoder_2.eval()
    vae.eval()
    transformer.eval()

    return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, transformer
================================================
FILE: src/invoke_training/_shared/flux/validation.py
================================================
import logging
import os
import numpy as np
import torch
import torch.utils.data
from accelerate import Accelerator
from accelerate.hooks import remove_hook_from_module
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
FluxPipeline,
FluxTransformer2DModel,
)
from peft import PeftModel
from transformers import CLIPTextModel, CLIPTokenizer
from invoke_training._shared.data.utils.resolution import Resolution
from invoke_training.pipelines.callbacks import PipelineCallbacks, ValidationImage, ValidationImages
from invoke_training.pipelines.flux.lora.config import FluxLoraConfig
# Number of denoising steps used for every validation image generated by `generate_validation_images_flux`.
NUM_INFERENCE_STEPS: int = 20
def generate_validation_images_flux(  # noqa: C901
    epoch: int,
    step: int,
    out_dir: str,
    accelerator: Accelerator,
    vae: AutoencoderKL,
    text_encoder_1: CLIPTextModel,
    text_encoder_2: CLIPTextModel,
    tokenizer_1: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    noise_scheduler: FlowMatchEulerDiscreteScheduler,
    transformer: FluxTransformer2DModel | PeftModel,
    config: FluxLoraConfig,
    logger: logging.Logger,
    callbacks: list[PipelineCallbacks] | None = None,
):
    """Generate validation images for the purpose of tracking image generation behaviour on fixed prompts throughout
    training.

    For each prompt in `config.validation_prompts`, `config.num_validation_images_per_prompt` images are generated.
    The images are:
    - saved as JPEGs under `<out_dir>/validation/epoch_<epoch>-step_<step>/prompt_<idx>/`;
    - logged to any tensorboard tracker;
    - passed to `callbacks` via `on_save_validation_images` once all prompts are done.

    The original devices of the models are restored afterwards, and any accelerate hooks are removed. This cleanup
    also runs if generation raises.

    NOTE(review): `text_encoder_2`/`tokenizer_2` are annotated as CLIP types, but `load_models_flux` supplies the T5
    encoder/tokenizer here. Confirm and update the hints.
    """
    # Record original model devices so that we can restore this state after running the pipeline with CPU model
    # offloading.
    models = [transformer, vae, text_encoder_1, text_encoder_2]
    original_devices = [model.device for model in models]

    # Create pipeline.
    pipeline = FluxPipeline(
        vae=vae,
        text_encoder=text_encoder_1,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer_1,
        tokenizer_2=tokenizer_2,
        transformer=transformer,
        scheduler=noise_scheduler,
    )
    try:
        if config.enable_cpu_offload_during_validation:
            pipeline.enable_model_cpu_offload(accelerator.device.index or 0)
        else:
            pipeline = pipeline.to(accelerator.device)
        pipeline.set_progress_bar_config(disable=True)

        validation_resolution = Resolution.parse(config.data_loader.resolution)
        validation_images = ValidationImages(images=[], epoch=epoch, step=step)
        validation_step_dir = os.path.join(out_dir, "validation", f"epoch_{epoch:0>8}-step_{step:0>8}")
        logger.info(f"Generating validation images ({validation_step_dir}).")

        # Run inference.
        with torch.no_grad():
            for prompt_idx, positive_prompt in enumerate(config.validation_prompts):
                negative_prompt = None
                logger.info(f"Validation prompt {prompt_idx}, pos: '{positive_prompt}', neg: '{negative_prompt or ''}'")

                # A fresh (optionally seeded) generator per prompt makes each prompt's images reproducible on their own.
                generator = torch.Generator(device=accelerator.device)
                if config.seed is not None:
                    generator = generator.manual_seed(config.seed)

                images = []
                for _ in range(config.num_validation_images_per_prompt):
                    with accelerator.autocast():
                        # NOTE(review): `negative_prompt` is always None here, so it is not forwarded. Only newer
                        # diffusers releases accept `negative_prompt` in FluxPipeline.__call__.
                        images.append(
                            pipeline(
                                positive_prompt,
                                num_inference_steps=NUM_INFERENCE_STEPS,
                                generator=generator,
                                height=validation_resolution.height,
                                width=validation_resolution.width,
                            ).images[0]
                        )

                # Save images to disk. `exist_ok=True` means that re-running validation for an already-written
                # epoch/step (e.g. after resuming) no longer fails with FileExistsError.
                validation_prompt_dir = os.path.join(validation_step_dir, f"prompt_{prompt_idx:0>4}")
                os.makedirs(validation_prompt_dir, exist_ok=True)
                for image_idx, image in enumerate(images):
                    image_path = os.path.join(validation_prompt_dir, f"{image_idx:0>4}.jpg")
                    validation_images.images.append(
                        ValidationImage(file_path=image_path, prompt=positive_prompt, image_idx=image_idx)
                    )
                    image.save(image_path)

                # Log images to trackers. Currently, only tensorboard is supported.
                for tracker in accelerator.trackers:
                    if tracker.name == "tensorboard":
                        np_images = np.stack([np.asarray(img) for img in images])
                        tracker.writer.add_images(
                            f"validation (prompt {prompt_idx})",
                            np_images,
                            step,
                            dataformats="NHWC",
                        )
    finally:
        # Always tear down the pipeline and undo offloading, even on failure, so training resumes with the models on
        # their original devices and without accelerate hooks attached.
        del pipeline
        torch.cuda.empty_cache()
        for model in models:
            remove_hook_from_module(model)
        for model, device in zip(models, original_devices, strict=True):
            model.to(device)

    # Run callbacks.
    if callbacks is not None:
        for cb in callbacks:
            cb.on_save_validation_images(images=validation_images)
================================================
FILE: src/invoke_training/_shared/optimizer/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/optimizer/optimizer_utils.py
================================================
import torch
from prodigyopt import Prodigy
from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig
def initialize_optimizer(
    config: AdamOptimizerConfig | ProdigyOptimizerConfig, trainable_params: list
) -> torch.optim.Optimizer:
    """Initialize an optimizer based on the provided config.

    Args:
        config: Optimizer config. `config.optimizer_type` selects the optimizer ("AdamW" or "Prodigy").
        trainable_params: Parameters passed to the optimizer constructor.

    Returns:
        torch.optim.Optimizer: The configured optimizer.

    Raises:
        ImportError: If 8-bit AdamW is requested (`config.use_8bit`) but bitsandbytes is not installed.
        ValueError: If `config.optimizer_type` is not supported.
    """
    if config.optimizer_type == "AdamW":
        adam_cls = torch.optim.AdamW
        if config.use_8bit:
            # bitsandbytes is an optional dependency, so it is only imported when 8-bit Adam is requested.
            try:
                import bitsandbytes
            except ImportError as err:
                raise ImportError(
                    "bitsandbytes is not installed. bitsandbytes is required to use the 8-bit Adam optimizer. Install "
                    'it by running `pip install ".[bitsandbytes]"`.'
                ) from err
            adam_cls = bitsandbytes.optim.AdamW8bit

        return adam_cls(
            trainable_params,
            lr=config.learning_rate,
            betas=(config.beta1, config.beta2),
            weight_decay=config.weight_decay,
            eps=config.epsilon,
        )

    if config.optimizer_type == "Prodigy":
        return Prodigy(
            trainable_params,
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
            use_bias_correction=config.use_bias_correction,
            safeguard_warmup=config.safeguard_warmup,
        )

    raise ValueError(f"'{config.optimizer_type}' is not a supported optimizer.")
================================================
FILE: src/invoke_training/_shared/stable_diffusion/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/stable_diffusion/base_model_version.py
================================================
from enum import Enum
from transformers import PretrainedConfig
class BaseModelVersionEnum(Enum):
    """Stable Diffusion base model families that `get_base_model_version` can detect."""

    # UNet cross_attention_dim == 768
    STABLE_DIFFUSION_V1 = 1
    # UNet cross_attention_dim == 1024
    STABLE_DIFFUSION_V2 = 2
    # UNet cross_attention_dim == 2048
    STABLE_DIFFUSION_SDXL_BASE = 3
    # UNet cross_attention_dim == 1280
    STABLE_DIFFUSION_SDXL_REFINER = 4
def get_base_model_version(
    diffusers_model_name: str, revision: str = "main", local_files_only: bool = True
) -> BaseModelVersionEnum:
    """Returns the `BaseModelVersionEnum` of a diffusers model.

    The version is inferred from the `cross_attention_dim` of the model's UNet config.

    Args:
        diffusers_model_name (str): The diffusers model name (on Hugging Face Hub).
        revision (str, optional): The model revision (branch or commit hash). Defaults to "main".
        local_files_only (bool, optional): Forwarded to `PretrainedConfig.from_pretrained`. If True, the UNet config
            is only read from local files and never downloaded. Defaults to True.

    Raises:
        Exception: If the base model version can not be determined.

    Returns:
        BaseModelVersionEnum: The detected base model version.
    """
    unet_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path=diffusers_model_name,
        revision=revision,
        subfolder="unet",
        local_files_only=local_files_only,
    )

    # This logic was copied from
    # https://github.com/invoke-ai/InvokeAI/blob/e77400ab62d24acbdf2f48a7427705e7b8b97e4a/invokeai/backend/model_management/model_probe.py#L412-L421
    # This seems fragile. If you see this and know of a better way to detect the base model version, your contribution
    # would be welcome.
    versions_by_cross_attention_dim = {
        768: BaseModelVersionEnum.STABLE_DIFFUSION_V1,
        1024: BaseModelVersionEnum.STABLE_DIFFUSION_V2,
        1280: BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_REFINER,
        2048: BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE,
    }
    try:
        return versions_by_cross_attention_dim[unet_config.cross_attention_dim]
    except (KeyError, TypeError):
        # TypeError covers unhashable values (e.g. a per-block list), which the old if/elif chain also rejected.
        raise Exception(
            "Failed to determine base model version. UNet cross_attention_dim has unexpected value: "
            f"'{unet_config.cross_attention_dim}'."
        ) from None
def check_base_model_version(
    allowed_versions: set[BaseModelVersionEnum],
    diffusers_model_name: str,
    revision: str = "main",
    local_files_only: bool = True,
):
    """Helper function that checks if a diffusers model is compatible with a set of base model versions.

    Args:
        allowed_versions (set[BaseModelVersionEnum]): The set of allowed base model versions.
        diffusers_model_name (str): The diffusers model name (on Hugging Face Hub) to check.
        revision (str, optional): The model revision (branch or commit hash). Defaults to "main".
        local_files_only (bool, optional): Forwarded to `get_base_model_version`. If True, the model config is only
            read from local files. Defaults to True.

    Raises:
        ValueError: If the model has an unsupported version.
        Exception: Propagated from `get_base_model_version` if the version cannot be determined at all.
    """
    version = get_base_model_version(diffusers_model_name, revision, local_files_only)
    if version in allowed_versions:
        return

    # Sort by enum value so that the error message does not depend on set iteration order.
    supported_names = [v.name for v in sorted(allowed_versions, key=lambda v: v.value)]
    raise ValueError(
        f"Model '{diffusers_model_name}' (revision='{revision}') has an unsupported version: '{version.name}'. "
        f"Supported versions: {supported_names}."
    )
================================================
FILE: src/invoke_training/_shared/stable_diffusion/checkpoint_utils.py
================================================
from pathlib import Path
import torch
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
def save_sdxl_diffusers_unet_checkpoint(
    checkpoint_path: Path | str, unet: UNet2DConditionModel, save_dtype: torch.dtype
):
    """Save only the UNet, in diffusers format, to `<checkpoint_path>/unet`.

    The UNet is temporarily cast to `save_dtype` for saving. Its original device and dtype are restored afterwards,
    even if saving raises. Restoring to a higher-precision dtype cannot recover precision lost by casting down to
    `save_dtype`.
    """
    # Record original device and dtype so that we can restore them afterward.
    original_device = unet.device
    original_dtype = unet.dtype
    try:
        unet.to(dtype=save_dtype)
        unet.save_pretrained(Path(checkpoint_path) / "unet")
    finally:
        # Restore even on failure so that the caller's training state is not left in `save_dtype`.
        unet.to(device=original_device, dtype=original_dtype)
def save_sdxl_diffusers_checkpoint(
    checkpoint_path: Path | str,
    vae: AutoencoderKL,
    text_encoder_1: CLIPTextModel,
    text_encoder_2: CLIPTextModel,
    tokenizer_1: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    noise_scheduler: DDPMScheduler,
    unet: UNet2DConditionModel,
    save_dtype: torch.dtype,
):
    """Save a full SDXL pipeline, in diffusers format, to `checkpoint_path`.

    The models are temporarily cast to `save_dtype` for saving. Their original devices and dtypes are restored
    afterwards, even if saving raises.
    """
    # Record original device and dtype so that we can restore it afterward.
    # TODO(ryand): This method of restoring original device/dtype is a bit naive. It does not handle mixed precisions
    # within a model, and results in a loss of precision if the save_dtype is lower precision than the model dtype. We
    # may need to revisit this.
    model_list = [vae, text_encoder_1, text_encoder_2, unet]
    original_devices = [model.device for model in model_list]
    original_dtypes = [model.dtype for model in model_list]

    try:
        pipeline = StableDiffusionXLPipeline(
            vae=vae,
            text_encoder=text_encoder_1,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer_1,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=noise_scheduler,
        )
        pipeline = pipeline.to(dtype=save_dtype)
        pipeline.save_pretrained(checkpoint_path)
    finally:
        # Restore even on failure so that the caller's training state is not left in `save_dtype`.
        for model, device, dtype in zip(model_list, original_devices, original_dtypes, strict=True):
            model.to(device=device, dtype=dtype)
================================================
FILE: src/invoke_training/_shared/stable_diffusion/lora_checkpoint_utils.py
================================================
import os
from pathlib import Path
import peft
import torch
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel
from invoke_training._shared.checkpoints.lora_checkpoint_utils import (
_convert_peft_models_to_kohya_state_dict,
_convert_peft_state_dict_to_kohya_state_dict,
load_multi_model_peft_checkpoint,
save_multi_model_peft_checkpoint,
)
from invoke_training._shared.checkpoints.serialization import save_state_dict
# Copied from https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/stable_diffusion/train_dreambooth.py#L49C1-L65C87
# Names of UNet sub-modules to wrap with LoRA adapters (presumably passed as `target_modules` to a peft LoraConfig;
# the consumer is outside this chunk).
UNET_TARGET_MODULES = [
    "to_q",
    "to_k",
    "to_v",
    "proj",
    "proj_in",
    "proj_out",
    "conv",
    "conv1",
    "conv2",
    "conv_shortcut",
    "to_out.0",
    "time_emb_proj",
    "ff.net.2",
]
# Names of CLIP text encoder sub-modules to wrap with LoRA adapters.
TEXT_ENCODER_TARGET_MODULES = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"]

# Module lists copied from diffusers training script.
# These module lists will produce lighter, less expressive, LoRA models than the non-light versions.
UNET_TARGET_MODULES_LIGHT = ["to_k", "to_q", "to_v", "to_out.0"]
TEXT_ENCODER_TARGET_MODULES_LIGHT = ["q_proj", "k_proj", "v_proj", "out_proj"]

# Keys identifying each sub-model in multi-model PEFT checkpoints. When converting a PEFT checkpoint to Kohya
# format, these keys are matched against the checkpoint's subdirectory names.
SD_PEFT_UNET_KEY = "unet"
SD_PEFT_TEXT_ENCODER_KEY = "text_encoder"

SDXL_PEFT_UNET_KEY = "unet"
SDXL_PEFT_TEXT_ENCODER_1_KEY = "text_encoder_1"
SDXL_PEFT_TEXT_ENCODER_2_KEY = "text_encoder_2"

# Key prefixes used for each sub-model in Kohya-format LoRA state dicts.
SD_KOHYA_UNET_KEY = "lora_unet"
SD_KOHYA_TEXT_ENCODER_KEY = "lora_te"

SDXL_KOHYA_UNET_KEY = "lora_unet"
SDXL_KOHYA_TEXT_ENCODER_1_KEY = "lora_te1"
SDXL_KOHYA_TEXT_ENCODER_2_KEY = "lora_te2"

# Map each PEFT sub-model key to the Kohya prefix used for it.
SD_PEFT_TO_KOHYA_KEYS = {
    SD_PEFT_UNET_KEY: SD_KOHYA_UNET_KEY,
    SD_PEFT_TEXT_ENCODER_KEY: SD_KOHYA_TEXT_ENCODER_KEY,
}

SDXL_PEFT_TO_KOHYA_KEYS = {
    SDXL_PEFT_UNET_KEY: SDXL_KOHYA_UNET_KEY,
    SDXL_PEFT_TEXT_ENCODER_1_KEY: SDXL_KOHYA_TEXT_ENCODER_1_KEY,
    SDXL_PEFT_TEXT_ENCODER_2_KEY: SDXL_KOHYA_TEXT_ENCODER_2_KEY,
}
def save_sd_peft_checkpoint(
    checkpoint_dir: Path | str, unet: peft.PeftModel | None, text_encoder: peft.PeftModel | None
):
    """Save the SD UNet and/or text encoder PEFT models to `checkpoint_dir`.

    Models given as None are left out. The rest are keyed by `SD_PEFT_UNET_KEY` / `SD_PEFT_TEXT_ENCODER_KEY` and
    handed to `save_multi_model_peft_checkpoint`.
    """
    candidates = ((SD_PEFT_UNET_KEY, unet), (SD_PEFT_TEXT_ENCODER_KEY, text_encoder))
    present_models = {key: model for key, model in candidates if model is not None}
    save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=present_models)
def load_sd_peft_checkpoint(
    checkpoint_dir: Path | str, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, is_trainable: bool = False
):
    """Load SD UNet / text encoder PEFT adapters from `checkpoint_dir` onto the given base models.

    A missing sub-model directory is not an error (`raise_if_subdir_missing=False`).

    Returns:
        (unet, text_encoder) as returned by `load_multi_model_peft_checkpoint`.
    """
    base_models = {SD_PEFT_UNET_KEY: unet, SD_PEFT_TEXT_ENCODER_KEY: text_encoder}
    loaded = load_multi_model_peft_checkpoint(
        checkpoint_dir=checkpoint_dir,
        models=base_models,
        is_trainable=is_trainable,
        raise_if_subdir_missing=False,
    )
    loaded_unet = loaded[SD_PEFT_UNET_KEY]
    loaded_text_encoder = loaded[SD_PEFT_TEXT_ENCODER_KEY]
    return loaded_unet, loaded_text_encoder
def save_sdxl_peft_checkpoint(
    checkpoint_dir: Path | str,
    unet: peft.PeftModel | None,
    text_encoder_1: peft.PeftModel | None,
    text_encoder_2: peft.PeftModel | None,
):
    """Save the SDXL UNet and/or text encoder PEFT models to `checkpoint_dir`.

    Models given as None are left out. The rest are keyed by the `SDXL_PEFT_*_KEY` constants and handed to
    `save_multi_model_peft_checkpoint`.
    """
    candidates = (
        (SDXL_PEFT_UNET_KEY, unet),
        (SDXL_PEFT_TEXT_ENCODER_1_KEY, text_encoder_1),
        (SDXL_PEFT_TEXT_ENCODER_2_KEY, text_encoder_2),
    )
    present_models = {key: model for key, model in candidates if model is not None}
    save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=present_models)
def load_sdxl_peft_checkpoint(
    checkpoint_dir: Path | str,
    unet: UNet2DConditionModel,
    text_encoder_1: CLIPTextModel,
    text_encoder_2: CLIPTextModel,
    is_trainable: bool = False,
):
    """Load SDXL UNet / text encoder PEFT adapters from `checkpoint_dir` onto the given base models.

    A missing sub-model directory is not an error (`raise_if_subdir_missing=False`).

    Returns:
        (unet, text_encoder_1, text_encoder_2) as returned by `load_multi_model_peft_checkpoint`.
    """
    ordered_keys = (SDXL_PEFT_UNET_KEY, SDXL_PEFT_TEXT_ENCODER_1_KEY, SDXL_PEFT_TEXT_ENCODER_2_KEY)
    base_models = dict(zip(ordered_keys, (unet, text_encoder_1, text_encoder_2)))
    loaded = load_multi_model_peft_checkpoint(
        checkpoint_dir=checkpoint_dir,
        models=base_models,
        is_trainable=is_trainable,
        raise_if_subdir_missing=False,
    )
    return tuple(loaded[key] for key in ordered_keys)
def save_sd_kohya_checkpoint(
    checkpoint_path: Path | str, unet: peft.PeftModel | None, text_encoder: peft.PeftModel | None
):
    """Save SD UNet and/or text encoder PEFT LoRA models as one Kohya-format state dict file.

    Models given as None are left out. The parent directories of `checkpoint_path` are created if needed. Accepts a
    `str` path as well as a `Path`, like the other save helpers in this module.
    """
    # Normalize so `.parent` works for str input too (previously an AttributeError).
    checkpoint_path = Path(checkpoint_path)

    candidates = [(SD_KOHYA_UNET_KEY, unet), (SD_KOHYA_TEXT_ENCODER_KEY, text_encoder)]
    present = [(prefix, model) for prefix, model in candidates if model is not None]

    kohya_state_dict = _convert_peft_models_to_kohya_state_dict(
        kohya_prefixes=[prefix for prefix, _ in present],
        models=[model for _, model in present],
    )

    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    save_state_dict(kohya_state_dict, checkpoint_path)
def save_sdxl_kohya_checkpoint(
    checkpoint_path: Path | str,
    unet: peft.PeftModel | None,
    text_encoder_1: peft.PeftModel | None,
    text_encoder_2: peft.PeftModel | None,
):
    """Save SDXL UNet and/or text encoder PEFT LoRA models as one Kohya-format state dict file.

    Models given as None are left out. The parent directories of `checkpoint_path` are created if needed. Accepts a
    `str` path as well as a `Path`, like the other save helpers in this module.
    """
    # Normalize so `.parent` works for str input too (previously an AttributeError).
    checkpoint_path = Path(checkpoint_path)

    candidates = [
        (SDXL_KOHYA_UNET_KEY, unet),
        (SDXL_KOHYA_TEXT_ENCODER_1_KEY, text_encoder_1),
        (SDXL_KOHYA_TEXT_ENCODER_2_KEY, text_encoder_2),
    ]
    present = [(prefix, model) for prefix, model in candidates if model is not None]

    kohya_state_dict = _convert_peft_models_to_kohya_state_dict(
        kohya_prefixes=[prefix for prefix, _ in present],
        models=[model for _, model in present],
    )

    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    save_state_dict(kohya_state_dict, checkpoint_path)
def convert_sd_peft_checkpoint_to_kohya_state_dict(
    in_checkpoint_dir: Path,
    out_checkpoint_file: Path,
    dtype: torch.dtype = torch.float32,
) -> dict[str, torch.Tensor]:
    """Convert SD v1 or SDXL PEFT models to a Kohya-format LoRA state dict, save it, and return it.

    Each immediate subdirectory of `in_checkpoint_dir` is treated as one PEFT LoRA model. Its directory name (a key of
    `SD_PEFT_TO_KOHYA_KEYS` or `SDXL_PEFT_TO_KOHYA_KEYS`) selects the Kohya key prefix. The converted weights of all
    models are merged into a single state dict.

    Args:
        in_checkpoint_dir: Directory whose immediate subdirectories are PEFT models. `str` is accepted too.
        out_checkpoint_file: Destination file. Its parent directories are created if needed. `str` is accepted too.
        dtype: dtype of the converted tensors (forwarded to the conversion helper).

    Returns:
        dict[str, torch.Tensor]: The Kohya-format state dict that was written to `out_checkpoint_file`.

    Raises:
        ValueError: If `in_checkpoint_dir` has no subdirectories, or has a subdirectory with an unrecognized name.
    """
    in_checkpoint_dir = Path(in_checkpoint_dir)
    out_checkpoint_file = Path(out_checkpoint_file)

    # We assume that each immediate subdirectory is a PEFT model. Sorting gives a deterministic processing order.
    peft_model_dirs = sorted(d for d in in_checkpoint_dir.iterdir() if d.is_dir())
    if len(peft_model_dirs) == 0:
        raise ValueError(f"No checkpoint files found in directory '{in_checkpoint_dir}'.")

    kohya_state_dict: dict[str, torch.Tensor] = {}
    for peft_model_dir in peft_model_dirs:
        if peft_model_dir.name in SD_PEFT_TO_KOHYA_KEYS:
            kohya_prefix = SD_PEFT_TO_KOHYA_KEYS[peft_model_dir.name]
        elif peft_model_dir.name in SDXL_PEFT_TO_KOHYA_KEYS:
            kohya_prefix = SDXL_PEFT_TO_KOHYA_KEYS[peft_model_dir.name]
        else:
            raise ValueError(f"Unrecognized checkpoint directory: '{peft_model_dir}'.")

        # Note: This logic to load the LoraConfig and weights directly is based on how it is done here:
        # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/src/peft/peft_model.py#L672-L689
        # This may need to be updated in the future to support other adapter types (LoKr, LoHa, etc.).
        # Also, I could see this interface breaking in the future.
        lora_config = peft.LoraConfig.from_pretrained(peft_model_dir)
        lora_weights = peft.utils.load_peft_weights(peft_model_dir, device="cpu")

        kohya_state_dict.update(
            _convert_peft_state_dict_to_kohya_state_dict(
                lora_config=lora_config, peft_state_dict=lora_weights, prefix=kohya_prefix, dtype=dtype
            )
        )

    out_checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
    save_state_dict(kohya_state_dict, out_checkpoint_file)
    # Return the dict as the signature promises (previously the function returned None).
    return kohya_state_dict
================================================
FILE: src/invoke_training/_shared/stable_diffusion/min_snr_weighting.py
================================================
import torch
from diffusers import DDPMScheduler
def compute_snr(noise_scheduler: DDPMScheduler, timesteps: torch.Tensor):
    """
    Computes SNR.

    For each timestep t, SNR(t) = (sqrt(alphas_cumprod[t]) / sqrt(1 - alphas_cumprod[t])) ** 2, computed in float32
    and broadcast to `timesteps.shape`.

    Adapted from:
    https://github.com/huggingface/diffusers/blob/ea9dc3fa90c70c7cd825ca2346a31153e08b5367/src/diffusers/training_utils.py#L40
    Which was originally based on:
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod

    def _gather_and_broadcast(per_step: torch.Tensor) -> torch.Tensor:
        # Pick the per-timestep values and pad trailing dims until they match `timesteps`, then expand.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        picked = per_step.to(device=timesteps.device)[timesteps].float()
        while picked.dim() < timesteps.dim():
            picked = picked[..., None]
        return picked.expand(timesteps.shape)

    alpha = _gather_and_broadcast(alphas_cumprod**0.5)
    sigma = _gather_and_broadcast((1.0 - alphas_cumprod) ** 0.5)
    return (alpha / sigma) ** 2
================================================
FILE: src/invoke_training/_shared/stable_diffusion/model_loading_utils.py
================================================
import logging
import os
import typing
from enum import Enum
import torch
from diffusers import (
AutoencoderKL,
DDPMScheduler,
StableDiffusionPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from transformers import CLIPTextModel, CLIPTokenizer
from invoke_training._shared.checkpoints.serialization import load_state_dict
HF_VARIANT_FALLBACKS = [None, "fp16"]
class PipelineVersionEnum(Enum):
    """Stable Diffusion pipeline families supported by `load_pipeline`."""

    SD = "SD"  # StableDiffusionPipeline
    SDXL = "SDXL"  # StableDiffusionXLPipeline
def load_pipeline(
    logger: logging.Logger,
    model_name_or_path: str,
    pipeline_version: PipelineVersionEnum,
    torch_dtype: torch.dtype | None = None,
    variant: str | None = None,
) -> typing.Union[StableDiffusionPipeline, StableDiffusionXLPipeline]:
    """Load a Stable Diffusion pipeline from disk.

    Args:
        logger (logging.Logger): Logger forwarded to `from_pretrained_with_variant_fallback`, which uses it to report
            variant fallbacks.
        model_name_or_path (str): The name or path of the pipeline to load. Can be in diffusers format, or a single
            stable diffusion checkpoint file. (E.g. 'runwayml/stable-diffusion-v1-5',
            'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/realisticVisionV51_v51VAE.safetensors', etc. )
        pipeline_version (PipelineVersionEnum): The pipeline version.
        torch_dtype (torch.dtype | None): dtype to load the pipeline with. None leaves the choice to diffusers.
        variant (str | None): The Hugging Face Hub variant. Ignored when `model_name_or_path` is a single checkpoint
            file.

    Returns:
        typing.Union[StableDiffusionPipeline, StableDiffusionXLPipeline]: The loaded pipeline.

    Raises:
        ValueError: If `pipeline_version` is not supported.
    """
    if pipeline_version == PipelineVersionEnum.SD:
        pipeline_class = StableDiffusionPipeline
    elif pipeline_version == PipelineVersionEnum.SDXL:
        pipeline_class = StableDiffusionXLPipeline
    else:
        raise ValueError(f"Unsupported pipeline_version: '{pipeline_version}'.")

    if os.path.isfile(model_name_or_path):
        # Single-file checkpoint: variants do not apply.
        return pipeline_class.from_single_file(
            model_name_or_path,
            torch_dtype=torch_dtype,
            safety_checker=None,
            feature_extractor=None,
        )

    return from_pretrained_with_variant_fallback(
        logger=logger,
        model_class=pipeline_class,
        model_name_or_path=model_name_or_path,
        torch_dtype=torch_dtype,
        variant=variant,
        # kwargs
        safety_checker=None,
        requires_safety_checker=False,
    )
# Type variable for the model class passed to `from_pretrained_with_variant_fallback`, so that its return type matches
# the class that was passed in.
ModelT = typing.TypeVar("ModelT")
def from_pretrained_with_variant_fallback(
    logger: logging.Logger,
    model_class: typing.Type[ModelT],
    model_name_or_path: str,
    torch_dtype: torch.dtype | None = None,
    variant: str | None = None,
    **kwargs,
) -> ModelT:
    """A wrapper for .from_pretrained() that tries multiple variants if the initial one fails.

    The requested `variant` is tried first, then each entry of `HF_VARIANT_FALLBACKS` that differs from it. Only
    "missing files" errors (OSError/ValueError whose message contains "no file named" or "no such modeling files are
    available") trigger a fallback. Any other error is re-raised immediately.

    Args:
        logger: Logger used to report failed attempts and fallbacks.
        model_class: Class whose `from_pretrained` is called.
        model_name_or_path: Model name or path forwarded to `from_pretrained`.
        torch_dtype: dtype forwarded to `from_pretrained`.
        variant: Variant to try first.
        **kwargs: Extra keyword arguments forwarded to `from_pretrained`.

    Returns:
        The first model that loads successfully.

    Raises:
        RuntimeError: If no variant could be loaded. The last "missing files" error, if any, is chained as the cause.
    """
    variants_to_try = [variant] + [v for v in HF_VARIANT_FALLBACKS if v != variant]

    # Keep the last recoverable error so the final RuntimeError explains *why* loading failed.
    last_error: Exception | None = None
    for variant_to_try in variants_to_try:
        if variant_to_try != variant:
            logger.warning(f"Trying fallback variant '{variant_to_try}'.")
        try:
            model = model_class.from_pretrained(
                model_name_or_path,
                torch_dtype=torch_dtype,
                variant=variant_to_try,
                **kwargs,
            )
        except (OSError, ValueError) as e:
            error_str = str(e)
            if "no file named" in error_str or "no such modeling files are available" in error_str:
                # Ok; we'll try the variant fallbacks.
                logger.warning(f"Failed to load '{model_name_or_path}' with variant '{variant_to_try}'. Error: {e}.")
                last_error = e
                continue
            raise

        if model is not None:
            return model

    raise RuntimeError(f"Failed to load model '{model_name_or_path}'.") from last_error
def load_models_sd(
    logger: logging.Logger,
    model_name_or_path: str,
    hf_variant: str | None = None,
    base_embeddings: dict[str, str] | None = None,
    dtype: torch.dtype | None = None,
) -> tuple[CLIPTokenizer, DDPMScheduler, CLIPTextModel, AutoencoderKL, UNet2DConditionModel]:
    """Load all Stable Diffusion (non-XL) models required for training from disk.

    The text encoder, VAE and UNet are frozen (`requires_grad_(False)`), optionally cast to `dtype`, and put in eval
    mode. They are not moved to any device here.

    Args:
        logger: Logger forwarded to `load_pipeline`.
        model_name_or_path: Diffusers model name/path or single checkpoint file (see `load_pipeline`).
        hf_variant: Hugging Face Hub variant (ignored for single checkpoint files).
        base_embeddings: Optional mapping of token -> textual inversion embedding path. Each entry is loaded into the
            pipeline with `load_textual_inversion` before the sub-models are extracted.
        dtype: If set, the text encoder, VAE and UNet are cast to this dtype.

    Returns:
        (tokenizer, noise_scheduler, text_encoder, vae, unet). The noise scheduler is a newly constructed
        `DDPMScheduler` (scaled_linear betas 0.00085-0.012, 1000 train timesteps), not the pipeline's scheduler.
    """
    base_embeddings = base_embeddings or {}

    pipeline: StableDiffusionPipeline = load_pipeline(
        logger=logger,
        model_name_or_path=model_name_or_path,
        pipeline_version=PipelineVersionEnum.SD,
        variant=hf_variant,
    )

    for token, embedding_path in base_embeddings.items():
        pipeline.load_textual_inversion(embedding_path, token=token)

    # Extract sub-models from the pipeline.
    tokenizer: CLIPTokenizer = pipeline.tokenizer
    text_encoder: CLIPTextModel = pipeline.text_encoder
    vae: AutoencoderKL = pipeline.vae
    unet: UNet2DConditionModel = pipeline.unet
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        clip_sample=False,
        steps_offset=1,
    )

    # Disable gradient calculation for model weights to save memory.
    text_encoder.requires_grad_(False)
    vae.requires_grad_(False)
    unet.requires_grad_(False)

    if dtype is not None:
        text_encoder = text_encoder.to(dtype=dtype)
        vae = vae.to(dtype=dtype)
        unet = unet.to(dtype=dtype)

    # Put models in 'eval' mode.
    text_encoder.eval()
    vae.eval()
    unet.eval()

    return tokenizer, noise_scheduler, text_encoder, vae, unet
def load_models_sdxl(
    logger: logging.Logger,
    model_name_or_path: str,
    hf_variant: str | None = None,
    vae_model: str | None = None,
    base_embeddings: dict[str, str] | None = None,
    dtype: torch.dtype | None = None,
) -> tuple[
    CLIPTokenizer,
    CLIPTokenizer,
    DDPMScheduler,
    CLIPTextModel,
    CLIPTextModel,
    AutoencoderKL,
    UNet2DConditionModel,
]:
    """Load all SDXL models required for training.

    The text encoders, VAE and UNet are frozen (`requires_grad_(False)`), optionally cast to `dtype`, and put in eval
    mode. They are not moved to any device here.

    Args:
        logger: Logger forwarded to `load_pipeline`.
        model_name_or_path: Diffusers model name/path or single checkpoint file (see `load_pipeline`).
        hf_variant: Hugging Face Hub variant (ignored for single checkpoint files).
        vae_model: If set, the pipeline's VAE is replaced by `AutoencoderKL.from_pretrained(vae_model)`.
        base_embeddings: Optional mapping of token -> textual inversion embedding file. Each file must contain a
            `clip_l` entry (loaded into text encoder 1) and a `clip_g` entry (loaded into text encoder 2).
        dtype: If set, the text encoders, VAE and UNet are cast to this dtype.

    Returns:
        (tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet). The noise scheduler
        is a newly constructed `DDPMScheduler`, not the pipeline's scheduler.
    """
    base_embeddings = base_embeddings or {}

    pipeline: StableDiffusionXLPipeline = load_pipeline(
        logger=logger,
        model_name_or_path=model_name_or_path,
        pipeline_version=PipelineVersionEnum.SDXL,
        variant=hf_variant,
    )

    for token, embedding_path in base_embeddings.items():
        state_dict = load_state_dict(embedding_path)
        pipeline.load_textual_inversion(
            state_dict["clip_l"],
            token=token,
            text_encoder=pipeline.text_encoder,
            tokenizer=pipeline.tokenizer,
        )
        pipeline.load_textual_inversion(
            state_dict["clip_g"],
            token=token,
            text_encoder=pipeline.text_encoder_2,
            tokenizer=pipeline.tokenizer_2,
        )

    # Extract sub-models from the pipeline.
    tokenizer_1: CLIPTokenizer = pipeline.tokenizer
    tokenizer_2: CLIPTokenizer = pipeline.tokenizer_2
    text_encoder_1: CLIPTextModel = pipeline.text_encoder
    text_encoder_2: CLIPTextModel = pipeline.text_encoder_2
    vae: AutoencoderKL = pipeline.vae
    unet: UNet2DConditionModel = pipeline.unet
    noise_scheduler = DDPMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        clip_sample=False,
        steps_offset=1,
    )

    if vae_model is not None:
        vae: AutoencoderKL = AutoencoderKL.from_pretrained(vae_model)

    # Disable gradient calculation for model weights to save memory.
    text_encoder_1.requires_grad_(False)
    text_encoder_2.requires_grad_(False)
    vae.requires_grad_(False)
    unet.requires_grad_(False)

    if dtype is not None:
        text_encoder_1 = text_encoder_1.to(dtype=dtype)
        text_encoder_2 = text_encoder_2.to(dtype=dtype)
        vae = vae.to(dtype=dtype)
        unet = unet.to(dtype=dtype)

    # Put models in 'eval' mode.
    text_encoder_1.eval()
    text_encoder_2.eval()
    vae.eval()
    unet.eval()

    return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet
================================================
FILE: src/invoke_training/_shared/stable_diffusion/textual_inversion.py
================================================
import logging
import torch
from accelerate import Accelerator
from transformers import CLIPTextModel, CLIPTokenizer, PreTrainedTokenizer
from invoke_training._shared.checkpoints.serialization import load_state_dict
def _expand_placeholder_token(placeholder_token: str, num_vectors: int = 1) -> list[str]:
    """Expand a placeholder token into a list of placeholder tokens based on the number of embedding vectors being
    trained.

    The first entry is `placeholder_token` itself; each additional vector gets a dummy token named
    `<placeholder_token>_<i>` for `i` in `1..num_vectors-1`.

    Args:
        placeholder_token: The user-facing placeholder token.
        num_vectors: The number of embedding vectors to train. Must be >= 1.

    Returns:
        The list of `num_vectors` placeholder tokens.

    Raises:
        ValueError: If `num_vectors < 1`.
    """
    # Validate before doing any work. The message previously said ">1" even though 1 is accepted.
    if num_vectors < 1:
        raise ValueError(f"num_vectors must be >= 1, but is '{num_vectors}'.")

    # The `<token>_<i>` naming matches the lookup that `expand_placeholders_in_caption` performs.
    return [placeholder_token] + [f"{placeholder_token}_{i}" for i in range(1, num_vectors)]
def _add_tokens_to_tokenizer(placeholder_tokens: list[str], tokenizer: PreTrainedTokenizer):
    """Register `placeholder_tokens` as new tokens in `tokenizer` (modified in place).

    Args:
        placeholder_tokens: The token strings to add.
        tokenizer: The tokenizer to extend.

    Raises:
        ValueError: If `tokenizer.add_tokens(...)` reports fewer added tokens than requested, i.e. at least one of
            `placeholder_tokens` was already present in the tokenizer.
    """
    added_count = tokenizer.add_tokens(placeholder_tokens)
    if added_count == len(placeholder_tokens):
        return
    raise ValueError(
        f"The tokenizer already contains one of the tokens in '{placeholder_tokens}'. Please pass a different"
        " 'placeholder_token' that is not already in the tokenizer."
    )
def expand_placeholders_in_caption(caption: str, tokenizer: CLIPTokenizer) -> str:
    """Expand any multi-vector placeholder tokens in the caption.

    Every added token that appears in the tokenized caption and has numbered companions (`<token>_1`, `<token>_2`,
    ...) in `tokenizer.added_tokens_encoder` is replaced, via `str.replace`, by the space-joined sequence
    `<token> <token>_1 <token>_2 ...`. For example, "a dog in the style of my_placeholder" could become
    "a dog in the style of my_placeholder my_placeholder_1 my_placeholder_2".

    This implementation is based on
    https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/textual_inversion.py#L144. This logic gets
    applied automatically when running a full diffusers text-to-image pipeline.

    Raises:
        AssertionError: If an expanded sequence is already present in the caption (which suggests the caller expanded
            it manually).
    """
    added_vocab = tokenizer.added_tokens_encoder
    for base_token in set(tokenizer.tokenize(caption)):
        if base_token not in added_vocab:
            continue

        # Collect the numbered companion tokens for as long as they exist.
        pieces = [base_token]
        suffix = 1
        while f"{base_token}_{suffix}" in added_vocab:
            pieces.append(f"{base_token}_{suffix}")
            suffix += 1

        # Single-vector tokens need no expansion.
        if len(pieces) == 1:
            continue

        expanded = " ".join(pieces)
        # If the expansion is already in the caption, someone probably didn't realize that placeholder expansion is
        # handled here.
        assert expanded not in caption
        caption = caption.replace(base_token, expanded)
    return caption
def initialize_placeholder_tokens_from_initializer_token(
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    initializer_token: str,
    placeholder_token: str,
    num_vectors: int,
    logger: logging.Logger,
) -> tuple[list[str], list[int]]:
    """Add `num_vectors` placeholder tokens and initialize all of their embeddings from a single initializer token.

    `tokenizer` and `text_encoder` are modified in place. The placeholder tokens are added to `tokenizer`, the input
    embedding table of `text_encoder` is resized to `len(tokenizer)`, and every new row is set to a copy of the
    embedding of the first token that `initializer_token` encodes to.

    Args:
        tokenizer: The tokenizer to extend.
        text_encoder: The text encoder whose input embeddings are resized and initialized.
        initializer_token: Text whose first token's embedding seeds the new placeholder embeddings.
        placeholder_token: The base placeholder token. Extra vectors are named `<placeholder_token>_<i>`.
        num_vectors: The number of placeholder tokens (embedding vectors) to create.
        logger: Used to warn when `initializer_token` encodes to more than one token.

    Returns:
        A `(placeholder_tokens, placeholder_token_ids)` tuple.

    Raises:
        ValueError: If `initializer_token` encodes to no tokens, if `num_vectors < 1`, or if one of the placeholder
            tokens is already in the tokenizer.
    """
    # Convert the initializer_token to a token id.
    initializer_token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
    if len(initializer_token_ids) == 0:
        # Raise a descriptive error instead of an IndexError on the `[0]` lookup below.
        raise ValueError(
            f"The initializer_token '{initializer_token}' does not map to any tokens. Please choose a different "
            "initializer_token."
        )
    if len(initializer_token_ids) > 1:
        logger.warning(
            f"The initializer_token '{initializer_token}' gets tokenized to {len(initializer_token_ids)} tokens. "
            "Only the first token will be used. It is recommended to choose a different initializer_token that maps to "
            "a single token."
        )
    initializer_token_id = initializer_token_ids[0]

    # Expand the tokenizer / text_encoder to include the placeholder tokens.
    placeholder_tokens = _expand_placeholder_token(placeholder_token, num_vectors=num_vectors)
    _add_tokens_to_tokenizer(placeholder_tokens, tokenizer)
    # Resize the token embeddings as we have added new special tokens to the tokenizer.
    text_encoder.resize_token_embeddings(len(tokenizer))

    placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
    # convert_tokens_to_ids returns a `int | list[int]` type, but since we pass in a list it should always return a
    # list.
    assert isinstance(placeholder_token_ids, list)

    # Initialize the newly-added placeholder token(s) with the embeddings of the initializer token.
    token_embeds = text_encoder.get_input_embeddings().weight.data
    with torch.no_grad():
        for placeholder_token_id in placeholder_token_ids:
            token_embeds[placeholder_token_id] = token_embeds[initializer_token_id].clone()
    return placeholder_tokens, placeholder_token_ids
def initialize_placeholder_tokens_from_initial_phrase(
    tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, initial_phrase: str, placeholder_token: str
) -> tuple[list[str], list[int]]:
    """Add one placeholder token per token of `initial_phrase` and initialize them from the phrase's embeddings.

    `tokenizer` and `text_encoder` are modified in place. The i-th placeholder token's embedding is set to a copy of
    the embedding of the i-th token of `initial_phrase`.

    Args:
        tokenizer: The tokenizer to extend.
        text_encoder: The text encoder whose input embeddings are resized and initialized.
        initial_phrase: The phrase whose token embeddings seed the new placeholder embeddings.
        placeholder_token: The base placeholder token. Extra vectors are named `<placeholder_token>_<i>`.

    Returns:
        A `(placeholder_tokens, placeholder_token_ids)` tuple.

    Raises:
        ValueError: If `initial_phrase` encodes to no tokens, or if one of the placeholder tokens is already in the
            tokenizer.
    """
    # Convert the initial_phrase to token ids.
    initial_token_ids = tokenizer.encode(initial_phrase, add_special_tokens=False)
    if len(initial_token_ids) == 0:
        # Without this check, `_expand_placeholder_token` would fail with a confusing "num_vectors ... '0'" message.
        raise ValueError(
            f"The initial_phrase '{initial_phrase}' does not map to any tokens. Please choose a non-empty "
            "initial_phrase."
        )

    # Expand the tokenizer / text_encoder to include one placeholder token for each token in the initial_phrase.
    placeholder_tokens = _expand_placeholder_token(placeholder_token, num_vectors=len(initial_token_ids))
    _add_tokens_to_tokenizer(placeholder_tokens, tokenizer)
    # Resize the token embeddings as we have added new special tokens to the tokenizer.
    text_encoder.resize_token_embeddings(len(tokenizer))

    placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
    # convert_tokens_to_ids returns a `int | list[int]` type, but since we pass in a list it should always return a
    # list.
    assert isinstance(placeholder_token_ids, list)

    # Initialize the newly-added placeholder token(s) with the embeddings of the initial phrase.
    token_embeds = text_encoder.get_input_embeddings().weight.data
    with torch.no_grad():
        for placeholder_token_id, initial_token_id in zip(placeholder_token_ids, initial_token_ids):
            token_embeds[placeholder_token_id] = token_embeds[initial_token_id].clone()
    return placeholder_tokens, placeholder_token_ids
def initialize_placeholder_tokens_from_initial_embedding(
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    initial_embedding_file: str,
    placeholder_token: str,
    num_vectors: int,
) -> tuple[list[str], list[int]]:
    """Add `num_vectors` placeholder tokens and initialize them from a saved embedding file.

    The file is loaded with `load_state_dict`. It must contain an entry keyed by `placeholder_token` whose first
    dimension equals `num_vectors`. Row i of that entry initializes the i-th placeholder token. The file is validated
    before `tokenizer` / `text_encoder` are modified, so an invalid file leaves them untouched.

    Args:
        tokenizer: The tokenizer to extend.
        text_encoder: The text encoder whose input embeddings are resized and initialized.
        initial_embedding_file: Path to the embedding file to initialize from.
        placeholder_token: The base placeholder token (also the key looked up in the file).
        num_vectors: The number of placeholder tokens (embedding vectors) to create.

    Returns:
        A `(placeholder_tokens, placeholder_token_ids)` tuple.

    Raises:
        ValueError: If `num_vectors < 1`, if the file has no entry for `placeholder_token`, if the entry's first
            dimension doesn't match `num_vectors`, or if one of the placeholder tokens is already in the tokenizer.
    """
    # Compute the list of per-vector placeholder tokens (this also validates num_vectors).
    placeholder_tokens = _expand_placeholder_token(placeholder_token, num_vectors=num_vectors)

    # Load and validate the initial embeddings BEFORE modifying the tokenizer / text_encoder, so that a bad file does
    # not leave them partially modified.
    state_dict = load_state_dict(initial_embedding_file)
    if placeholder_token not in state_dict:
        raise ValueError(
            f"The initial embedding at '{initial_embedding_file}' does not contain an embedding for placeholder token "
            f"'{placeholder_token}'."
        )
    embeddings = state_dict[placeholder_token]
    if embeddings.shape[0] != len(placeholder_tokens):
        raise ValueError(
            f"The number of initial embeddings in '{initial_embedding_file}' ({embeddings.shape[0]}) does not match "
            f"the expected number of placeholder tokens ({len(placeholder_tokens)})."
        )

    # Expand the tokenizer / text_encoder to include the placeholder tokens.
    _add_tokens_to_tokenizer(placeholder_tokens, tokenizer)
    # Resize the token embeddings as we have added new special tokens to the tokenizer.
    text_encoder.resize_token_embeddings(len(tokenizer))

    placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
    # convert_tokens_to_ids returns a `int | list[int]` type, but since we pass in a list it should always return a
    # list.
    assert isinstance(placeholder_token_ids, list)

    # Initialize the newly-added placeholder token(s) with the loaded embeddings.
    token_embeds = text_encoder.get_input_embeddings().weight.data
    with torch.no_grad():
        for i, token_id in enumerate(placeholder_token_ids):
            token_embeds[token_id] = embeddings[i].clone()
    return placeholder_tokens, placeholder_token_ids
def restore_original_embeddings(
    tokenizer: CLIPTokenizer,
    placeholder_token_ids: list[int],
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    orig_embeds_params: torch.Tensor,
):
    """Restore the text_encoder embeddings that we are not actively training to make sure they don't change.

    Then rescale the trained (placeholder) embeddings toward the spread of the frozen ones.

    Every row of the unwrapped text encoder's input embedding table whose index is not in `placeholder_token_ids` is
    overwritten with the corresponding row of `orig_embeds_params`. The placeholder rows are then multiplied by
    `(std(frozen rows) / std(placeholder rows)) ** 0.1`, which nudges their spread toward the frozen embeddings.

    TODO(ryand): Look into whether this is actually necessary if we set requires_grad correctly.

    Args:
        tokenizer: Only used for `len(tokenizer)`, which sets the number of embedding rows considered.
        placeholder_token_ids: Ids of the embedding rows being trained. They do not need to be contiguous.
        accelerator: Used to unwrap `text_encoder` before editing its weights.
        text_encoder: The (possibly wrapped) text encoder whose input embeddings are modified in place.
        orig_embeds_params: Snapshot of the original input embedding weights to restore from. Presumably it has at
            least `len(tokenizer)` rows in the same layout as the live weights; confirm at call sites.
    """
    # Mark exactly the trainable rows. The previous `min..max` slice silently stopped restoring any frozen rows that
    # happened to lie between non-contiguous placeholder ids.
    index_updates = torch.zeros((len(tokenizer),), dtype=torch.bool)
    index_updates[torch.as_tensor(placeholder_token_ids, dtype=torch.long)] = True
    index_no_updates = ~index_updates

    with torch.no_grad():
        embedding_weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
        embedding_weight[index_no_updates] = orig_embeds_params[index_no_updates]

        if not bool(index_updates.any()):
            # Nothing is being trained, so there is nothing to rescale.
            return

        target_std = embedding_weight[index_no_updates].std()
        new_embeddings = embedding_weight[index_updates]
        target_over_new_std = target_std / new_embeddings.std()
        # Scale the new embeddings towards the target embeddings. Raise to the 0.1 power to avoid large changes.
        embedding_weight[index_updates] = new_embeddings * (target_over_new_std**0.1)
================================================
FILE: src/invoke_training/_shared/stable_diffusion/tokenize_captions.py
================================================
import torch
from transformers import CLIPTokenizer
from invoke_training._shared.stable_diffusion.textual_inversion import expand_placeholders_in_caption
def tokenize_captions(tokenizer: CLIPTokenizer, captions: list[str]) -> torch.Tensor:
    """Tokenize a list of captions.

    Multi-vector placeholder tokens in each caption are expanded first (via `expand_placeholders_in_caption`). Each
    caption is then padded / truncated to exactly `tokenizer.model_max_length` tokens.

    Args:
        tokenizer (CLIPTokenizer): The tokenizer.
        captions (list[str]): The captions to tokenize.

    Returns:
        torch.Tensor: The token IDs, with shape `(len(captions), tokenizer.model_max_length)`. An empty `captions`
            list yields an empty `(0, tokenizer.model_max_length)` tensor of dtype `torch.long`.
    """
    if len(captions) == 0:
        # `torch.stack([])` raises, so build the empty result explicitly.
        return torch.empty((0, tokenizer.model_max_length), dtype=torch.long)

    caption_token_ids = []
    for caption in captions:
        caption = expand_placeholders_in_caption(caption, tokenizer)
        tokenized = tokenizer(
            caption,
            max_length=tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # `input_ids` has a leading batch dim of 1; drop it so the results can be stacked.
        caption_token_ids.append(tokenized.input_ids[0, ...])
    return torch.stack(caption_token_ids)
================================================
FILE: src/invoke_training/_shared/stable_diffusion/validation.py
================================================
import logging
import os
import numpy as np
import torch
import torch.utils.data
from accelerate import Accelerator
from accelerate.hooks import remove_hook_from_module
from diffusers import (
AutoencoderKL,
DDPMScheduler,
StableDiffusionPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from transformers import CLIPTextModel, CLIPTokenizer
from invoke_training._shared.data.utils.resolution import Resolution
from invoke_training.pipelines.callbacks import PipelineCallbacks, ValidationImage, ValidationImages
from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig
from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig
def generate_validation_images_sd(  # noqa: C901
    epoch: int,
    step: int,
    out_dir: str,
    accelerator: Accelerator,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    noise_scheduler: DDPMScheduler,
    unet: UNet2DConditionModel,
    config: SdLoraConfig,
    logger: logging.Logger,
    callbacks: list[PipelineCallbacks] | None = None,
):
    """Generate validation images for the purpose of tracking image generation behaviour on fixed prompts throughout
    training.

    Generation details:
    - For each prompt in `config.validation_prompts`, `config.num_validation_images_per_prompt` images are generated.
      Each uses 30 inference steps at `config.data_loader.resolution`.
    - The generator is re-seeded with `config.seed` (if set) for every prompt.
    - Images are saved under `<out_dir>/validation/epoch_<epoch>-step_<step>/prompt_<idx>/`.
    - Images are logged to a tensorboard tracker if one is attached to `accelerator`.
    - Each callback's `on_save_validation_images(...)` is invoked once all prompts are done.

    Cleanup: any offload hooks are removed and the models are moved back to their original devices, even if generation
    raises.

    Args:
        epoch: The current epoch (used in the output directory name and in the reported `ValidationImages`).
        step: The current step (used in the output directory name, the `ValidationImages`, and tensorboard).
        out_dir: The run output directory. Images are written under `<out_dir>/validation/`.
        accelerator: Provides the device, the autocast context, and the trackers.
        vae, text_encoder, tokenizer, noise_scheduler, unet: Model components used to build the pipeline.
        config: Training config. Prompts, negative prompts, seed, image count, resolution, and the CPU-offload flag
            are read from it.
        logger: Logger for progress messages.
        callbacks: Optional callbacks notified with the saved validation images.
    """
    # Record original model devices so that we can restore this state after running the pipeline with CPU model
    # offloading.
    unet_device = unet.device
    vae_device = vae.device
    text_encoder_device = text_encoder.device

    # Create pipeline.
    pipeline = StableDiffusionPipeline(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        scheduler=noise_scheduler,
        safety_checker=None,
        feature_extractor=None,
        # TODO(ryand): Add safety checker support.
        requires_safety_checker=False,
    )

    try:
        if config.enable_cpu_offload_during_validation:
            pipeline.enable_model_cpu_offload(accelerator.device.index or 0)
        else:
            pipeline = pipeline.to(accelerator.device)
        pipeline.set_progress_bar_config(disable=True)

        validation_resolution = Resolution.parse(config.data_loader.resolution)
        validation_images = ValidationImages(images=[], epoch=epoch, step=step)
        validation_step_dir = os.path.join(out_dir, "validation", f"epoch_{epoch:0>8}-step_{step:0>8}")
        logger.info(f"Generating validation images ({validation_step_dir}).")

        # Run inference.
        with torch.no_grad():
            for prompt_idx in range(len(config.validation_prompts)):
                positive_prompt = config.validation_prompts[prompt_idx]
                negative_prompt = None
                if config.negative_validation_prompts is not None:
                    negative_prompt = config.negative_validation_prompts[prompt_idx]
                logger.info(
                    f"Validation prompt {prompt_idx}, pos: '{positive_prompt}', neg: '{negative_prompt or ''}'"
                )

                # Re-seed per prompt so that each prompt's images are reproducible independently.
                generator = torch.Generator(device=accelerator.device)
                if config.seed is not None:
                    generator = generator.manual_seed(config.seed)

                images = []
                for _ in range(config.num_validation_images_per_prompt):
                    with accelerator.autocast():
                        images.append(
                            pipeline(
                                positive_prompt,
                                num_inference_steps=30,
                                generator=generator,
                                height=validation_resolution.height,
                                width=validation_resolution.width,
                                negative_prompt=negative_prompt,
                            ).images[0]
                        )

                # Save images to disk.
                validation_prompt_dir = os.path.join(validation_step_dir, f"prompt_{prompt_idx:0>4}")
                os.makedirs(validation_prompt_dir)
                for image_idx, image in enumerate(images):
                    image_path = os.path.join(validation_prompt_dir, f"{image_idx:0>4}.jpg")
                    validation_images.images.append(
                        ValidationImage(file_path=image_path, prompt=positive_prompt, image_idx=image_idx)
                    )
                    image.save(image_path)

                # Log images to trackers. Currently, only tensorboard is supported.
                for tracker in accelerator.trackers:
                    if tracker.name == "tensorboard":
                        np_images = np.stack([np.asarray(img) for img in images])
                        tracker.writer.add_images(
                            f"validation (prompt {prompt_idx})",
                            np_images,
                            step,
                            dataformats="NHWC",
                        )
    finally:
        # Always undo the pipeline's side effects so the models are left usable for the caller, even if generation
        # raised part-way through.
        del pipeline
        torch.cuda.empty_cache()

        # Remove hooks from models.
        # HACK(ryand): Hooks get added when calling `pipeline.enable_model_cpu_offload(...)`, but
        # `StableDiffusionPipeline` does not offer a way to clean them up so we have to do this manually.
        for model in [unet, vae, text_encoder]:
            remove_hook_from_module(model)

        # Restore models to original devices.
        unet.to(unet_device)
        vae.to(vae_device)
        text_encoder.to(text_encoder_device)

    # Run callbacks.
    if callbacks is not None:
        for cb in callbacks:
            cb.on_save_validation_images(images=validation_images)
def generate_validation_images_sdxl(  # noqa: C901
    epoch: int,
    step: int,
    out_dir: str,
    accelerator: Accelerator,
    vae: AutoencoderKL,
    text_encoder_1: CLIPTextModel,
    text_encoder_2: CLIPTextModel,
    tokenizer_1: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    noise_scheduler: DDPMScheduler,
    unet: UNet2DConditionModel,
    config: SdxlLoraConfig,
    logger: logging.Logger,
    callbacks: list[PipelineCallbacks] | None = None,
):
    """Generate validation images for the purpose of tracking image generation behaviour on fixed prompts throughout
    training.

    Generation details:
    - For each prompt in `config.validation_prompts`, `config.num_validation_images_per_prompt` images are generated
      with a `StableDiffusionXLPipeline`. Each uses 30 inference steps at `config.data_loader.resolution`.
    - The generator is re-seeded with `config.seed` (if set) for every prompt.
    - Images are saved under `<out_dir>/validation/epoch_<epoch>-step_<step>/prompt_<idx>/`.
    - Images are logged to a tensorboard tracker if one is attached to `accelerator`.
    - Each callback's `on_save_validation_images(...)` is invoked once all prompts are done.

    Cleanup: any offload hooks are removed and the models are moved back to their original devices, even if generation
    raises.

    Args:
        epoch: The current epoch (used in the output directory name and in the reported `ValidationImages`).
        step: The current step (used in the output directory name, the `ValidationImages`, and tensorboard).
        out_dir: The run output directory. Images are written under `<out_dir>/validation/`.
        accelerator: Provides the device, the autocast context, and the trackers.
        vae, text_encoder_1, text_encoder_2, tokenizer_1, tokenizer_2, noise_scheduler, unet: Model components used
            to build the pipeline.
        config: Training config. Prompts, negative prompts, seed, image count, resolution, and the CPU-offload flag
            are read from it.
        logger: Logger for progress messages.
        callbacks: Optional callbacks notified with the saved validation images.
    """
    # Record original model devices so that we can restore this state after running the pipeline with CPU model
    # offloading.
    unet_device = unet.device
    vae_device = vae.device
    text_encoder_1_device = text_encoder_1.device
    text_encoder_2_device = text_encoder_2.device

    # Create pipeline.
    pipeline = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoder_1,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer_1,
        tokenizer_2=tokenizer_2,
        unet=unet,
        scheduler=noise_scheduler,
    )

    try:
        if config.enable_cpu_offload_during_validation:
            pipeline.enable_model_cpu_offload(accelerator.device.index or 0)
        else:
            pipeline = pipeline.to(accelerator.device)
        pipeline.set_progress_bar_config(disable=True)

        validation_resolution = Resolution.parse(config.data_loader.resolution)
        validation_images = ValidationImages(images=[], epoch=epoch, step=step)
        validation_step_dir = os.path.join(out_dir, "validation", f"epoch_{epoch:0>8}-step_{step:0>8}")
        logger.info(f"Generating validation images ({validation_step_dir}).")

        # Run inference.
        with torch.no_grad():
            for prompt_idx in range(len(config.validation_prompts)):
                positive_prompt = config.validation_prompts[prompt_idx]
                negative_prompt = None
                if config.negative_validation_prompts is not None:
                    negative_prompt = config.negative_validation_prompts[prompt_idx]
                logger.info(
                    f"Validation prompt {prompt_idx}, pos: '{positive_prompt}', neg: '{negative_prompt or ''}'"
                )

                # Re-seed per prompt so that each prompt's images are reproducible independently.
                generator = torch.Generator(device=accelerator.device)
                if config.seed is not None:
                    generator = generator.manual_seed(config.seed)

                images = []
                for _ in range(config.num_validation_images_per_prompt):
                    with accelerator.autocast():
                        images.append(
                            pipeline(
                                positive_prompt,
                                num_inference_steps=30,
                                generator=generator,
                                height=validation_resolution.height,
                                width=validation_resolution.width,
                                negative_prompt=negative_prompt,
                            ).images[0]
                        )

                # Save images to disk.
                validation_prompt_dir = os.path.join(validation_step_dir, f"prompt_{prompt_idx:0>4}")
                os.makedirs(validation_prompt_dir)
                for image_idx, image in enumerate(images):
                    image_path = os.path.join(validation_prompt_dir, f"{image_idx:0>4}.jpg")
                    validation_images.images.append(
                        ValidationImage(file_path=image_path, prompt=positive_prompt, image_idx=image_idx)
                    )
                    image.save(image_path)

                # Log images to trackers. Currently, only tensorboard is supported.
                for tracker in accelerator.trackers:
                    if tracker.name == "tensorboard":
                        np_images = np.stack([np.asarray(img) for img in images])
                        tracker.writer.add_images(
                            f"validation (prompt {prompt_idx})",
                            np_images,
                            step,
                            dataformats="NHWC",
                        )
    finally:
        # Always undo the pipeline's side effects so the models are left usable for the caller, even if generation
        # raised part-way through.
        del pipeline
        torch.cuda.empty_cache()

        # Remove hooks from models.
        # HACK(ryand): Hooks get added when calling `pipeline.enable_model_cpu_offload(...)`, but
        # `StableDiffusionXLPipeline` does not offer a way to clean them up so we have to do this manually.
        for model in [unet, vae, text_encoder_1, text_encoder_2]:
            remove_hook_from_module(model)

        # Restore models to original devices.
        unet.to(unet_device)
        vae.to(vae_device)
        text_encoder_1.to(text_encoder_1_device)
        text_encoder_2.to(text_encoder_2_device)

    # Run callbacks.
    if callbacks is not None:
        for cb in callbacks:
            cb.on_save_validation_images(images=validation_images)
================================================
FILE: src/invoke_training/_shared/tools/__init__.py
================================================
================================================
FILE: src/invoke_training/_shared/tools/generate_images.py
================================================
import os
from pathlib import Path
from typing import Optional
import torch
from tqdm import tqdm
from invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset
from invoke_training._shared.stable_diffusion.model_loading_utils import (
PipelineVersionEnum,
load_pipeline,
)
def generate_images(
    out_dir: str,
    model: str,
    hf_variant: str | None,
    pipeline_version: PipelineVersionEnum,
    prompts: list[str],
    set_size: int,
    num_sets: int,
    height: int,
    width: int,
    loras: Optional[list[tuple[Path, float]]] = None,
    ti_embeddings: Optional[list[str]] = None,
    seed: int | None = 0,
    torch_dtype: torch.dtype = torch.float16,
    torch_device: str = "cuda",
    enable_cpu_offload: bool = False,
):
    """Generate a set of images and store them in a directory. Typically used to generate a dataset for prior
    preservation / regularization.

    Images are written to `<out_dir>/prompt-<prompt_idx>/set-<set_idx>/image-<image_idx>.jpg`. A metadata entry per
    set, holding the prompt, the out_dir-relative image paths, and `prefer_<i>` flags all set to False, is saved via
    `ImagePairPreferenceDataset.save_metadata`.

    Args:
        out_dir (str): The output directory to create. It must not already exist.
        model (str): The name or path of the diffusers pipeline to generate with.
        hf_variant (str | None): The Hugging Face variant to load (e.g. "fp16"), or None for the default.
        pipeline_version (PipelineVersionEnum): The model version.
        prompts (list[str]): The prompts to generate images with.
        set_size (int): The number of images in a 'set' for a given prompt.
        num_sets (int): The number of 'sets' to generate for each prompt.
        height (int): The output image height in pixels (recommended to match the resolution that the model was trained
            with).
        width (int): The output image width in pixels (recommended to match the resolution that the model was trained
            with).
        loras (list[tuple[Path, float]], optional): Paths to LoRA models to apply to the base model with associated
            weights.
        ti_embeddings (list[str], optional): Paths to TI embeddings to apply to the base model.
        seed (int | None, optional): A seed for repeatability. Defaults to 0. If None, the generator is not seeded.
        torch_dtype (torch.dtype, optional): The torch dtype. Defaults to torch.float16.
        torch_device (str, optional): The torch device. Defaults to "cuda".
        enable_cpu_offload (bool, optional): If True, models will be loaded onto the GPU one by one to conserve VRAM.
            Defaults to False.
    """
    # Local import: only needed to supply a logger to `load_pipeline`.
    import logging

    # Pass a logger like the other `load_pipeline` call sites do.
    pipeline = load_pipeline(
        logger=logging.getLogger(__name__),
        model_name_or_path=model,
        pipeline_version=pipeline_version,
        variant=hf_variant,
    )

    # Fuse each LoRA into the base weights at its requested scale.
    for lora_path, lora_scale in loras or []:
        pipeline.load_lora_weights(str(lora_path), weight_name=str(lora_path.name))
        pipeline.fuse_lora(lora_scale=lora_scale)

    for ti_embedding in ti_embeddings or []:
        pipeline.load_textual_inversion(ti_embedding)

    pipeline.to(torch_dtype=torch_dtype)
    if enable_cpu_offload:
        pipeline.enable_model_cpu_offload()
    else:
        pipeline.to(torch_device=torch_device)
    pipeline.set_progress_bar_config(disable=True)

    generator = torch.Generator(device=torch_device)
    if seed is not None:
        generator = generator.manual_seed(seed)

    os.makedirs(out_dir)

    metadata = []
    total_images = num_sets * len(prompts) * set_size
    with torch.no_grad(), tqdm(total=total_images) as pbar:
        for prompt_idx, prompt in enumerate(prompts):
            for set_idx in range(num_sets):
                set_dir = os.path.join(out_dir, f"prompt-{prompt_idx:0>4}", f"set-{set_idx:0>4}")
                os.makedirs(set_dir)

                set_metadata_dict = {"prompt": prompt}
                for image_idx in range(set_size):
                    image = pipeline(
                        prompt,
                        num_inference_steps=30,
                        generator=generator,
                        height=height,
                        width=width,
                    ).images[0]
                    image_path = os.path.join(set_dir, f"image-{image_idx}.jpg")
                    image.save(image_path)
                    set_metadata_dict[f"image_{image_idx}"] = os.path.relpath(image_path, start=out_dir)
                    set_metadata_dict[f"prefer_{image_idx}"] = False
                    pbar.update(1)
                metadata.append(set_metadata_dict)

    ImagePairPreferenceDataset.save_metadata(metadata=metadata, dataset_dir=out_dir)
================================================
FILE: src/invoke_training/_shared/utils/import_xformers.py
================================================
def import_xformers():
    """Import `xformers` to verify that it is available, raising an actionable error if it is not.

    The module object is not returned; the import is performed only as an availability check.

    Raises:
        ImportError: If `xformers` cannot be imported. The original import error is chained as `__cause__`, so
            failures other than "not installed" (e.g. a broken install) remain visible in the traceback.
    """
    try:
        import xformers  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "xformers is not installed. Either set `xformers = False` in your training config, or install it using "
            '`pip install ".[xformers]"`.'
        ) from err
================================================
FILE: src/invoke_training/_shared/utils/jsonl.py
================================================
import json
from pathlib import Path
from typing import Any
def load_jsonl(jsonl_path: Path | str) -> list[Any]:
"""Load a JSONL file."""
data = []
with open(jsonl_path) as f:
while (line := f.readline().strip()) != "":
data.append(json.loads(line))
return data
def save_jsonl(data: list[Any], jsonl_path: Path | str) -> None:
"""Save a list of objects to a JSONL file."""
with open(jsonl_path, "w") as f:
for line in data:
f.write(json.dumps(line) + "\n")
================================================
FILE: src/invoke_training/config/__init__.py
================================================
================================================
FILE: src/invoke_training/config/base_pipeline_config.py
================================================
import typing
from typing import Optional
from invoke_training.config.config_base_model import ConfigBaseModel
class BasePipelineConfig(ConfigBaseModel):
    """A base config with fields that should be inherited by all pipelines."""

    # NOTE(review): the "One of ... should be set" constraints documented below are not enforced by any validator in
    # this class. Presumably they are checked by subclasses or by the training loop; confirm.

    # Identifies the pipeline this config belongs to. Declared as a plain `str` here; presumably narrowed to a
    # `Literal` by each concrete pipeline config (TODO confirm).
    type: str

    seed: Optional[int] = None
    """A randomization seed for reproducible training. Set to any constant integer for consistent training results. If
    set to `null`, training will be non-deterministic.
    """

    base_output_dir: str
    """The output directory where the training outputs (model checkpoints, logs, intermediate predictions) will be
    written. A subdirectory will be created with a timestamp for each new training run.
    """

    report_to: typing.Literal["all", "tensorboard", "wandb", "comet_ml"] = "tensorboard"
    """The integration to report results and logs to. This value is passed to Hugging Face Accelerate. See
    `accelerate.Accelerator.log_with` for more details.
    """

    max_train_steps: int | None = None
    """Total number of training steps to perform. One training step is one gradient update.

    One of `max_train_steps` or `max_train_epochs` should be set.
    """

    max_train_epochs: int | None = None
    """Total number of training epochs to perform. One epoch is one pass over the entire dataset.

    One of `max_train_steps` or `max_train_epochs` should be set.
    """

    save_every_n_epochs: int | None = None
    """The interval (in epochs) at which to save checkpoints.

    One of `save_every_n_epochs` or `save_every_n_steps` should be set.
    """

    save_every_n_steps: int | None = None
    """The interval (in steps) at which to save checkpoints.

    One of `save_every_n_epochs` or `save_every_n_steps` should be set.
    """

    validate_every_n_epochs: int | None = None
    """The interval (in epochs) at which validation images will be generated.

    One of `validate_every_n_epochs` or `validate_every_n_steps` should be set.
    """

    validate_every_n_steps: int | None = None
    """The interval (in steps) at which validation images will be generated.

    One of `validate_every_n_epochs` or `validate_every_n_steps` should be set.
    """
================================================
FILE: src/invoke_training/config/config_base_model.py
================================================
from pydantic import BaseModel, ConfigDict
class ConfigBaseModel(BaseModel):
    """Base model for all invoke training configuration models."""

    # Configure to raise if extra fields are passed in. With `extra="forbid"`, pydantic rejects unknown keys with a
    # validation error, so misspelled config fields are reported instead of being silently ignored.
    model_config = ConfigDict(extra="forbid")
================================================
FILE: src/invoke_training/config/data/__init__.py
================================================
================================================
FILE: src/invoke_training/config/data/data_loader_config.py
================================================
from typing import Literal, Optional
from invoke_training.config.config_base_model import ConfigBaseModel
from invoke_training.config.data.dataset_config import (
ImageCaptionDatasetConfig,
ImageDirDatasetConfig,
)
class AspectRatioBucketConfig(ConfigBaseModel):
    # NOTE(review): no validator here enforces `start_dim <= end_dim` or `divisible_by > 0`. Confirm these are
    # validated where the buckets are built.

    target_resolution: int
    """The target resolution for all aspect ratios. When generating aspect ratio buckets, the resolution of each bucket
    is selected to have roughly `target_resolution * target_resolution` pixels (i.e. a square image with dimensions
    equal to `target_resolution`).
    """

    start_dim: int
    """Aspect ratio bucket resolutions are generated as follows:

    - Iterate over 'first' dimension values from `start_dim` to `end_dim` in steps of size `divisible_by`.
    - Calculate the 'second' dimension to be as close as possible to the total number of pixels in
      `target_resolution`, while still being divisible by `divisible_by`.

    tip: Choosing aspect ratio buckets
        The aspect ratio bucket resolutions are logged at the start of training with the number of images in each
        bucket. Review these logs to make sure that images are being split into buckets as expected.

        Highly fragmented splits (i.e. many buckets with few examples in each) can 1) limit the extent to which
        examples can be shuffled, and 2) slow down training if there are many partial batches.
    """

    end_dim: int
    """See explanation under
    [`start_dim`][invoke_training.config.data.data_loader_config.AspectRatioBucketConfig.start_dim].
    """

    divisible_by: int
    """See explanation under
    [`start_dim`][invoke_training.config.data.data_loader_config.AspectRatioBucketConfig.start_dim].
    """
class ImageCaptionSDDataLoaderConfig(ConfigBaseModel):
    # Fixed type tag for this data loader config. Presumably used to tell data loader config types apart when a config
    # is parsed (TODO confirm).
    type: Literal["IMAGE_CAPTION_SD_DATA_LOADER"] = "IMAGE_CAPTION_SD_DATA_LOADER"

    # The image-caption dataset to load examples from.
    dataset: ImageCaptionDatasetConfig

    # Optional aspect ratio bucketing config. While unset, all images are resized to `resolution`.
    aspect_ratio_buckets: AspectRatioBucketConfig | None = None

    resolution: int | tuple[int, int] = 512
    """The resolution for input images. Either a scalar integer representing the square resolution height and width, or
    a (height, width) tuple. All of the images in the dataset will be resized to this resolution unless the
    `aspect_ratio_buckets` config is set.
    """

    center_crop: bool = True
    """If True, input images will be center-cropped to the target resolution.

    If False, input images will be randomly cropped to the target resolution.
    """

    random_flip: bool = False
    """Whether random flip augmentations should be applied to input images.
    """

    caption_prefix: str | None = None
    """A prefix that will be prepended to all captions. If None, no prefix will be added.
    """

    dataloader_num_workers: int = 0
    """Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
    """
class ImageCaptionFluxDataLoaderConfig(ConfigBaseModel):
    # Fixed type tag for this data loader config. Presumably used to tell data loader config types apart when a config
    # is parsed (TODO confirm).
    type: Literal["IMAGE_CAPTION_FLUX_DATA_LOADER"] = "IMAGE_CAPTION_FLUX_DATA_LOADER"

    # The image-caption dataset to load examples from.
    dataset: ImageCaptionDatasetConfig

    # Optional aspect ratio bucketing config. While unset, all images are resized to `resolution`.
    aspect_ratio_buckets: AspectRatioBucketConfig | None = None

    resolution: int | tuple[int, int] = 512
    """The resolution for input images. Either a scalar integer representing the square resolution height and width, or
    a (height, width) tuple. All of the images in the dataset will be resized to this resolution unless the
    `aspect_ratio_buckets` config is set.
    """

    center_crop: bool = True
    """If True, input images will be center-cropped to the target resolution.

    If False, input images will be randomly cropped to the target resolution.
    """

    random_flip: bool = False
    """Whether random flip augmentations should be applied to input images.
    """

    caption_prefix: str | None = None
    """A prefix that will be prepended to all captions. If None, no prefix will be added.
    """

    dataloader_num_workers: int = 0
    """Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
    """
class DreamboothSDDataLoaderConfig(ConfigBaseModel):
    # Fixed type tag for this data loader config. Presumably used to tell data loader config types apart when a config
    # is parsed (TODO confirm).
    type: Literal["DREAMBOOTH_SD_DATA_LOADER"] = "DREAMBOOTH_SD_DATA_LOADER"

    # Caption used for examples from `instance_dataset`. Presumably applied uniformly to every instance image;
    # confirm in the Dreambooth data loader.
    instance_caption: str
    # Caption used for examples from `class_dataset`, if one is provided. Presumably applied uniformly; confirm.
    class_caption: Optional[str] = None

    # Image directory dataset of instance images.
    instance_dataset: ImageDirDatasetConfig
    # Optional image directory dataset of class images. Their loss is weighted by `class_data_loss_weight`.
    class_dataset: Optional[ImageDirDatasetConfig] = None

    class_data_loss_weight: float = 1.0
    """The loss weight applied to class dataset examples. Instance dataset examples have an implicit loss weight of 1.0.
    """

    aspect_ratio_buckets: AspectRatioBucketConfig | None = None
    """The aspect ratio bucketing configuration. If None, aspect ratio bucketing is disabled, and all images will be
    resized to the same resolution.
    """

    resolution: int | tuple[int, int] = 512
    """The resolution for input images. Either a scalar integer representing the square resolution height and width, or
    a (height, width) tuple. All of the images in the dataset will be resized to this resolution unless the
    `aspect_ratio_buckets` config is set.
    """

    center_crop: bool = True
    """If True, input images will be center-cropped to the target resolution.

    If False, input images will be randomly cropped to the target resolution.
    """

    random_flip: bool = False
    """Whether random flip augmentations should be applied to input images.
    """

    dataloader_num_workers: int = 0
    """Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
    """
class TextualInversionSDDataLoaderConfig(ConfigBaseModel):
    # Fixed type tag for this data loader config. Presumably used to tell data loader config types apart when a config
    # is parsed (TODO confirm).
    type: Literal["TEXTUAL_INVERSION_SD_DATA_LOADER"] = "TEXTUAL_INVERSION_SD_DATA_LOADER"

    # The source images: either a plain image directory or an image-caption dataset.
    dataset: ImageDirDatasetConfig | ImageCaptionDatasetConfig

    # A built-in caption preset ("style" or "object") used to generate captions. See `keep_original_captions` for how
    # generated captions combine with any original captions. Presumably an alternative to `caption_templates`;
    # confirm whether both may be set.
    caption_preset: Literal["style", "object"] | None = None

    caption_templates: list[str] | None = None
    """A list of caption templates with a single template argument 'slot' in each.

    E.g.:

    - "a photo of a {}"
    - "a rendering of a {}"
    - "a cropped photo of the {}"
    """

    keep_original_captions: bool = False
    """If `True`, then the captions generated as a result of the `caption_preset` or `caption_templates` will be used as
    prefixes for the original captions. If `False`, then the generated captions will replace the original captions.
    """

    aspect_ratio_buckets: AspectRatioBucketConfig | None = None
    """The aspect ratio bucketing configuration. If None, aspect ratio bucketing is disabled, and all images will be
    resized to the same resolution.
    """

    resolution: int | tuple[int, int] = 512
    """The resolution for input images. Either a scalar integer representing the square resolution height and width, or
    a (height, width) tuple. All of the images in the dataset will be resized to this resolution unless the
    `aspect_ratio_buckets` config is set.
    """

    center_crop: bool = True
    """If True, input images will be center-cropped to the target resolution.

    If False, input images will be randomly cropped to the target resolution.
    """

    random_flip: bool = False
    """Whether random flip augmentations should be applied to input images.
    """

    shuffle_caption_delimiter: str | None = None
    """If `None`, then no caption shuffling is applied. If set, then captions are split on this delimiter and shuffled.
    """

    dataloader_num_workers: int = 0
    """Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
    """
================================================
FILE: src/invoke_training/config/data/dataset_config.py
================================================
from typing import Annotated, Literal, Optional, Union
from pydantic import Field
from invoke_training.config.config_base_model import ConfigBaseModel
class HFHubImageCaptionDatasetConfig(ConfigBaseModel):
    # An image-caption dataset hosted on the Hugging Face Hub.
    # `type` is the discriminator value that selects this class in the `ImageCaptionDatasetConfig` union.
    type: Literal["HF_HUB_IMAGE_CAPTION_DATASET"] = "HF_HUB_IMAGE_CAPTION_DATASET"
    dataset_name: str
    """The name of a Hugging Face dataset.
    """
    dataset_config_name: Optional[str] = None
    """The Hugging Face dataset config name. Leave as None if there's only one config.
    """
    hf_cache_dir: Optional[str] = None
    """The Hugging Face cache directory to use for dataset downloads.
    If None, the default value will be used (usually '~/.cache/huggingface/datasets').
    """
    image_column: str = "image"
    """The name of the dataset column that contains image paths.
    """
    caption_column: str = "text"
    """The name of the dataset column that contains captions.
    """
class ImageCaptionJsonlDatasetConfig(ConfigBaseModel):
    # An image-caption dataset described by a JSONL file (one record per line).
    # `type` is the discriminator value that selects this class in the `ImageCaptionDatasetConfig` union.
    type: Literal["IMAGE_CAPTION_JSONL_DATASET"] = "IMAGE_CAPTION_JSONL_DATASET"
    jsonl_path: str
    """The path to a JSONL file containing image paths and captions."""
    image_column: str = "image"
    """The name of the dataset column that contains image paths.
    """
    caption_column: str = "text"
    """The name of the dataset column that contains captions.
    """
    keep_in_memory: bool = False
    """If `True`, load all images into memory on initialization so that they can be accessed quickly. If `False`, images
    are loaded from disk each time they are accessed. Setting to `True` improves performance for datasets that are small
    enough to be kept in memory.
    """
class ImageDirDatasetConfig(ConfigBaseModel):
    # A dataset made up of the images in a single directory.
    # Unlike the image-caption configs in this module, this class is NOT a member of the `ImageCaptionDatasetConfig`
    # union.
    type: Literal["IMAGE_DIR_DATASET"] = "IMAGE_DIR_DATASET"
    dataset_dir: str
    """The directory to load images from."""
    keep_in_memory: bool = False
    """If `True`, load all images into memory on initialization so that they can be accessed quickly. If `False`, images
    are loaded from disk each time they are accessed. Setting to `True` improves performance for datasets that are small
    enough to be kept in memory.
    """
class ImageCaptionDirDatasetConfig(ConfigBaseModel):
    # An image-caption dataset loaded from a directory.
    # `type` is the discriminator value that selects this class in the `ImageCaptionDatasetConfig` union.
    # NOTE(review): how captions are paired with images inside `dataset_dir` is not defined here - see the
    # corresponding dataset implementation.
    type: Literal["IMAGE_CAPTION_DIR_DATASET"] = "IMAGE_CAPTION_DIR_DATASET"
    dataset_dir: str
    """The directory to load images from."""
    keep_in_memory: bool = False
    """If `True`, load all images into memory on initialization so that they can be accessed quickly. If `False`, images
    are loaded from disk each time they are accessed. Setting to `True` improves performance for datasets that are small
    enough to be kept in memory.
    """
# Datasets that produce image-caption pairs.
# This is a pydantic discriminated union: the value of the `type` field in the input data selects which member config
# class is used for validation.
ImageCaptionDatasetConfig = Annotated[
    Union[HFHubImageCaptionDatasetConfig, ImageCaptionJsonlDatasetConfig, ImageCaptionDirDatasetConfig],
    Field(discriminator="type"),
]
================================================
FILE: src/invoke_training/config/optimizer/__init__.py
================================================
================================================
FILE: src/invoke_training/config/optimizer/optimizer_config.py
================================================
import typing
from invoke_training.config.config_base_model import ConfigBaseModel
class AdamOptimizerConfig(ConfigBaseModel):
    # Configuration for the AdamW optimizer. `optimizer_type` identifies this variant among the optimizer configs.
    optimizer_type: typing.Literal["AdamW"] = "AdamW"
    learning_rate: float = 1e-4
    """Initial learning rate to use (after the potential warmup period). Note that in some training pipelines this can
    be overridden for a specific group of params: https://pytorch.org/docs/stable/optim.html#per-parameter-options
    (E.g. see `text_encoder_learning_rate` and `unet_learning_rate`)
    """
    # NOTE(review): the four fields below presumably map onto `torch.optim.AdamW(betas=(beta1, beta2),
    # weight_decay=weight_decay, eps=epsilon)`. The optimizer factory is not visible here - confirm.
    beta1: float = 0.9
    beta2: float = 0.999
    weight_decay: float = 1e-2
    epsilon: float = 1e-8
    use_8bit: bool = False
    """Use an 8-bit version of the Adam optimizer. This requires the bitsandbytes library to be installed. use_8bit
    reduces the VRAM usage of the optimizer, but increases the risk of issues with numerical stability.
    """
class ProdigyOptimizerConfig(ConfigBaseModel):
    # Configuration for the Prodigy optimizer. `optimizer_type` identifies this variant among the optimizer configs.
    optimizer_type: typing.Literal["Prodigy"] = "Prodigy"
    learning_rate: float = 1.0
    """The learning rate. For the Prodigy optimizer, the learning rate is adjusted dynamically. A value of 1.0 is
    recommended. Note that in some training pipelines this can be overridden for a specific group of params:
    https://pytorch.org/docs/stable/optim.html#per-parameter-options (E.g. see `text_encoder_learning_rate` and
    `unet_learning_rate`)
    """
    # NOTE(review): the fields below presumably map onto the same-named arguments of the Prodigy optimizer. The
    # optimizer factory is not visible here - confirm.
    weight_decay: float = 0.0
    use_bias_correction: bool = False
    safeguard_warmup: bool = False
================================================
FILE: src/invoke_training/config/pipeline_config.py
================================================
from typing import Annotated, Union
from pydantic import Field
from invoke_training.pipelines._experimental.sd_dpo_lora.config import SdDirectPreferenceOptimizationLoraConfig
from invoke_training.pipelines.flux.lora.config import FluxLoraConfig
from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig
from invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig
from invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig
from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig
from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import (
SdxlLoraAndTextualInversionConfig,
)
from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig
# Union of all supported training pipeline configs.
# Pydantic uses the value of the `type` field in the input data as the discriminator, to select which config class
# the data is validated against.
PipelineConfig = Annotated[
    Union[
        FluxLoraConfig,
        SdLoraConfig,
        SdxlLoraConfig,
        SdTextualInversionConfig,
        SdxlTextualInversionConfig,
        SdxlLoraAndTextualInversionConfig,
        SdxlFinetuneConfig,
        SdDirectPreferenceOptimizationLoraConfig,
    ],
    Field(discriminator="type"),
]
================================================
FILE: src/invoke_training/model_merge/__init__.py
================================================
================================================
FILE: src/invoke_training/model_merge/extract_lora.py
================================================
import torch
import tqdm
from peft.peft_model import PeftModel
# All original base model weights in a PeftModel have this prefix and suffix.
# E.g. the original weight of a LoRA-wrapped module named `foo.bar` is stored under the state_dict key
# `base_model.model.foo.bar.base_layer.weight`. Stripping both yields the module name in the un-wrapped model.
PEFT_BASE_LAYER_PREFIX = "base_model.model."
PEFT_BASE_LAYER_SUFFIX = ".base_layer.weight"
def get_patched_base_weights_from_peft_model(peft_model: PeftModel) -> dict[str, torch.Tensor]:
    """Get a state_dict containing the base model weights *that are patched* in the provided PeftModel. I.e. only
    return base model weights that have associated LoRA layers, but don't return the LoRA layers.

    The returned keys are module names in the wrapped (non-PEFT) model: the PEFT key
    `base_model.model.<module_name>.base_layer.weight` is returned under the key `<module_name>`.

    Raises:
        ValueError: If a `.base_layer.weight` key does not start with the expected `base_model.model.` prefix.
    """
    state_dict = peft_model.state_dict()
    out_state_dict: dict[str, torch.Tensor] = {}
    for weight_name, weight in state_dict.items():
        # Weights that end with ".base_layer.weight" are the original weights for LoRA layers.
        if not weight_name.endswith(PEFT_BASE_LAYER_SUFFIX):
            continue

        module_name = weight_name.removesuffix(PEFT_BASE_LAYER_SUFFIX)
        # Explicit check rather than `assert`: asserts are stripped under `python -O`, which would silently produce
        # mangled module names.
        if not module_name.startswith(PEFT_BASE_LAYER_PREFIX):
            raise ValueError(
                f"Unexpected PEFT base layer weight name (expected prefix '{PEFT_BASE_LAYER_PREFIX}'): '{weight_name}'."
            )
        out_state_dict[module_name.removeprefix(PEFT_BASE_LAYER_PREFIX)] = weight
    return out_state_dict
def get_state_dict_diff(
    state_dict_1: dict[str, torch.Tensor], state_dict_2: dict[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Compute the element-wise difference `state_dict_1 - state_dict_2`.

    The result has exactly the keys of `state_dict_1`, in the same order. Every key of `state_dict_1` must also exist
    in `state_dict_2` (otherwise a KeyError is raised). Keys that only exist in `state_dict_2` are ignored.
    """
    diff: dict[str, torch.Tensor] = {}
    for name, minuend in state_dict_1.items():
        diff[name] = minuend - state_dict_2[name]
    return diff
@torch.no_grad()
def extract_lora_from_diffs(
    diffs: dict[str, torch.Tensor], rank: int, clamp_quantile: float, out_dtype: torch.dtype
) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    """Approximate each weight diff with a rank-`rank` LoRA factorization using a truncated SVD.

    Args:
        diffs: Map from module name to weight diff. 2D diffs are treated as Linear weights of shape (out_dim, in_dim),
            4D diffs as Conv2d weights of shape (out_dim, in_dim, kernel_h, kernel_w).
        rank: The LoRA rank. Must satisfy 1 <= rank <= min(out_dim, in_features) for every diff, where in_features is
            in_dim for Linear weights and in_dim * kernel_h * kernel_w for Conv2d weights.
        clamp_quantile: The factor values are clamped to [-q, q], where q is this quantile (in [0, 1]) of all values in
            the two factors.
        out_dtype: The dtype of the returned tensors.

    Returns:
        Map from module name to (lora_up, lora_down), a.k.a. (lora_B, lora_A). The shapes are (out_dim, rank) and
        (rank, in_dim) for Linear weights, and (out_dim, rank, 1, 1) and (rank, in_dim, kernel_h, kernel_w) for Conv2d
        weights. The tensors are on the same device as the corresponding diff.

    Raises:
        ValueError: If a diff is neither 2D nor 4D, or if `rank` is out of range for a diff.
    """
    lora_weights = {}
    for lora_name, mat in tqdm.tqdm(list(diffs.items())):
        # Use full precision for the intermediate calculations.
        mat = mat.to(torch.float32)

        is_conv2d = False
        if len(mat.shape) == 4:  # Conv2D
            is_conv2d = True
            out_dim, in_dim, kernel_h, kernel_w = mat.shape
            # Reshape to (out_dim, in_dim * kernel_h * kernel_w).
            mat = mat.flatten(start_dim=1)
        elif len(mat.shape) == 2:  # Linear
            out_dim, in_dim = mat.shape
        else:
            raise ValueError(f"Unexpected weight shape: {mat.shape}")

        # The LoRA rank cannot exceed the dimensions of the (flattened) weight matrix. For Conv2d weights, the relevant
        # input dimension is in_dim * kernel_h * kernel_w, not in_dim. A rank equal to the smaller dimension is valid
        # (the factorization can then represent the diff exactly).
        max_rank = min(mat.shape)
        if not 1 <= rank <= max_rank:
            raise ValueError(
                f"LoRA rank {rank} is invalid for '{lora_name}' (flattened weight shape {tuple(mat.shape)}). The rank "
                f"must be in the range [1, {max_rank}]."
            )

        # Reduced SVD: u is (out_dim, k) and v_h is (k, in_features), with k = min(out_dim, in_features). Only the top
        # `rank` singular vectors are needed, so there is no need to materialize the full square U and V^H matrices
        # (which can be very large, e.g. for Conv2d weights with 3x3 kernels).
        u, s, v_h = torch.linalg.svd(mat, full_matrices=False)

        # Apply the Eckart-Young-Mirsky theorem.
        # https://en.wikipedia.org/wiki/Low-rank_approximation#Proof_of_Eckart%E2%80%93Young%E2%80%93Mirsky_theorem_(for_Frobenius_norm)
        u = u[:, :rank]
        s = s[:rank]
        u = u @ torch.diag(s)

        v_h = v_h[:rank, :]

        # At this point, u is the lora_up (a.k.a. lora_B) weight, and v_h is the lora_down (a.k.a. lora_A) weight.
        # The reason we don't use more appropriate variable names is to keep memory usage low - we want the old tensors
        # to get cleaned up after each operation.

        # Clamp the outliers.
        dist = torch.cat([u.flatten(), v_h.flatten()])
        hi_val = torch.quantile(dist, clamp_quantile)
        low_val = -hi_val

        u = u.clamp(low_val, hi_val)
        v_h = v_h.clamp(low_val, hi_val)

        if is_conv2d:
            u = u.reshape(out_dim, rank, 1, 1)
            v_h = v_h.reshape(rank, in_dim, kernel_h, kernel_w)

        u = u.to(dtype=out_dtype).contiguous()
        v_h = v_h.to(dtype=out_dtype).contiguous()

        lora_weights[lora_name] = (u, v_h)
    return lora_weights
================================================
FILE: src/invoke_training/model_merge/merge_models.py
================================================
from typing import Literal
import torch
import tqdm
from invoke_training.model_merge.utils.normalize_weights import normalize_weights
@torch.no_grad()
def merge_models(
state_dicts: list[dict[str, torch.Tensor]], weights: list[float], merge_method: Literal["LERP", "SLERP"] = "LERP"
):
"""Merge multiple models into a single model.
Args:
state_dicts (list[dict[str, torch.Tensor]]): The state dicts to merge.
weights (list[float]): The weights for each state dict. The weights will be normalized to sum to 1.
merge_method (Literal["LERP", "SLERP"]): Merge method to use. Options:
- "LERP": Linear interpolation a.k.a. weighted sum.
- "SLERP": Spherical linear interpolation.
"""
if len(state_dicts) < 2:
raise ValueError("Must provide >=2 models to merge.")
if len(state_dicts) != len(weights):
raise ValueError("Must provide a weight for each model.")
if merge_method == "LERP":
merge_fn = lerp
elif merge_method == "SLERP":
merge_fn = slerp
else:
raise ValueError(f"Unknown merge method: {merge_method}")
normalized_weights = normalize_weights(weights)
out_state_dict: dict[str, torch.Tensor] = state_dicts[0].copy()
out_state_dict_weight = normalized_weights[0]
for state_dict, normalized_weight in zip(state_dicts[1:], normalized_weights[1:], strict=True):
if state_dict.keys() != out_state_dict.keys():
raise ValueError("State dicts must have the same keys.")
cur_pair_weights = normalize_weights([out_state_dict_weight, normalized_weight])
for key in tqdm.tqdm(out_state_dict.keys()):
out_state_dict[key] = merge_fn(out_state_dict[key], state_dict[key], cur_pair_weights[0])
# Update the weight of out_state_dict to be the sum of all state dicts merged so far.
out_state_dict_weight += normalized_weight
return out_state_dict
def lerp(a: torch.Tensor, b: torch.Tensor, weight_a: float) -> torch.Tensor:
    """Linear interpolation between `a` and `b`.

    `weight_a` is the weight given to `a`; `b` receives `1.0 - weight_a`. So `weight_a=1.0` yields `a` and
    `weight_a=0.0` yields `b`.
    """
    weight_b = 1.0 - weight_a
    return torch.lerp(a, b, weight_b)
def slerp(a: torch.Tensor, b: torch.Tensor, weight_a: float, dot_product_thres=0.9995, epsilon=1e-10):
    """Spherical linear interpolation.

    The interpolation angle is computed between the directions of `a` and `b` (each treated as one flattened vector).
    `weight_a` is the weight of `a`; `b` receives `1.0 - weight_a`. The resulting coefficients are applied to the
    original (un-normalized) tensors.

    Falls back to `lerp` in two cases: when either input has a norm below `epsilon` (its direction is undefined), and
    when the inputs are nearly colinear (|cos(angle)| > `dot_product_thres`), where the slerp formula is numerically
    unstable.
    """
    # TODO(ryand): For multi-dimensional matrices, it might be better to apply slerp on a subset of the dimensions
    # (e.g. per-row), rather than treating the entire matrix as a single flattened vector.

    a_norm = torch.linalg.norm(a)
    b_norm = torch.linalg.norm(b)

    if a_norm < epsilon or b_norm < epsilon:
        # The fallback is necessary: normalizing a (near-)zero vector would divide by ~0 and produce inf/NaN values.
        # This check happens before normalizing so that no full-size normalized copies are allocated only to be
        # thrown away.
        return lerp(a, b, weight_a)

    # Normalize the vectors.
    a_normalized = a / a_norm
    b_normalized = b / b_norm

    # Dot product of the normalized vectors.
    # We are effectively treating multi-dimensional tensors as flattened vectors.
    dot_prod = torch.sum(a_normalized * b_normalized)

    # If the absolute value of the dot product is almost 1, the vectors are ~colinear, so use lerp.
    if torch.abs(dot_prod) > dot_product_thres:
        return lerp(a, b, weight_a)

    # Calculate initial angle between the vectors.
    theta_0 = torch.acos(dot_prod)

    # Angle at timestep t.
    t = 1.0 - weight_a
    theta_t = theta_0 * t

    sin_theta_0 = torch.sin(theta_0)
    sin_theta_t = torch.sin(theta_t)

    s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
    s1 = sin_theta_t / sin_theta_0
    result = s0 * a + s1 * b
    return result
================================================
FILE: src/invoke_training/model_merge/merge_tasks_to_base.py
================================================
from typing import Literal
import torch
import tqdm
from peft.utils.merge_utils import dare_linear, dare_ties, ties
@torch.no_grad()
def merge_tasks_to_base_model(
    base_state_dict: dict[str, torch.Tensor],
    task_state_dicts: list[dict[str, torch.Tensor]],
    task_weights: list[float],
    density: float = 0.2,
    merge_method: Literal["TIES", "DARE_LINEAR", "DARE_TIES"] = "TIES",
) -> dict[str, torch.Tensor]:
    """Merge a base model with one or more task-specific models.

    Args:
        base_state_dict (dict[str, torch.Tensor]): The base state dict to merge with.
        task_state_dicts (list[dict[str, torch.Tensor]]): A list of task-specific state dicts to merge into the base
            state dict. Each must contain every key of `base_state_dict`.
        task_weights (list[float]): Weights for each task state dict. Weights of 1.0 for all task_state_dicts are
            recommended as a starting point (e.g. [1.0, 1.0, 1.0]). The weights can be adjusted from there (e.g.
            [1.0, 1.3, 1.0]). The weights are multipliers applied to the diff between each task_state_dict and the base
            model.
        density (float, optional): The fraction of values to preserve in the prune/trim step of DARE/TIES methods.
            Should be in the range [0, 1].
        merge_method (Literal["TIES", "DARE_LINEAR", "DARE_TIES"], optional): The method to use for merging. Options:
            - "TIES": Use the TIES method (https://arxiv.org/pdf/2306.01708)
            - "DARE_LINEAR": Use the DARE method with linear interpolation (https://arxiv.org/pdf/2311.03099)
            - "DARE_TIES": Use the DARE method for pruning, and the TIES method for merging.

    Returns:
        dict[str, torch.Tensor]: A new state dict with the keys of `base_state_dict`. Each tensor is the base tensor
        plus the merged task diffs, cast back to the base tensor's dtype.

    Raises:
        ValueError: If the number of weights does not match the number of task state dicts, no task state dicts are
            given, or `merge_method` is unknown.
    """
    if len(task_state_dicts) != len(task_weights):
        raise ValueError("Must provide a weight for each model.")
    if len(task_state_dicts) == 0:
        # Without this check, the failure would surface as a cryptic error deep inside the merge function.
        raise ValueError("Must provide at least one task model to merge.")

    # The peft merge utilities expect the weights as a tensor.
    task_weights_tensor = torch.tensor(task_weights)

    # Choose the merging method.
    if merge_method == "TIES":
        merge_fn = ties
    elif merge_method == "DARE_LINEAR":
        merge_fn = dare_linear
    elif merge_method == "DARE_TIES":
        merge_fn = dare_ties
    else:
        raise ValueError(f"Unknown merge method: {merge_method}")

    out_state_dict: dict[str, torch.Tensor] = {}
    for key in tqdm.tqdm(base_state_dict.keys()):
        base_tensor = base_state_dict[key]
        orig_dtype = base_tensor.dtype

        # Calculate the diff between each task tensor and the base tensor.
        task_diff_tensors = [state_dict[key] - base_tensor for state_dict in task_state_dicts]

        merged_diff_tensor = merge_fn(
            task_tensors=task_diff_tensors,
            weights=task_weights_tensor,
            density=density,
        )

        # Some of the merge_fn implementations may return a tensor with a different dtype than the original tensors.
        # We cast the merged_diff_tensor back to the original dtype here.
        out_state_dict[key] = (base_tensor + merged_diff_tensor).to(dtype=orig_dtype)

    return out_state_dict
================================================
FILE: src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py
================================================
# This script is based on
# https://raw.githubusercontent.com/kohya-ss/sd-scripts/bfb352bc433326a77aca3124248331eb60c49e8c/networks/extract_lora_from_models.py
# That script was originally based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py
import argparse
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
import peft
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection
from invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str
from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (
TEXT_ENCODER_TARGET_MODULES,
UNET_TARGET_MODULES,
save_sdxl_kohya_checkpoint,
)
from invoke_training._shared.stable_diffusion.model_loading_utils import (
PipelineVersionEnum,
from_pretrained_with_variant_fallback,
load_pipeline,
)
from invoke_training.model_merge.extract_lora import (
PEFT_BASE_LAYER_PREFIX,
extract_lora_from_diffs,
get_patched_base_weights_from_peft_model,
get_state_dict_diff,
)
from invoke_training.model_merge.utils.parse_model_arg import parse_model_arg
@dataclass
class StableDiffusionModel:
    """Container for the SD/SDXL submodels that LoRA extraction operates on. Each submodel may be absent (None)."""

    unet: UNet2DConditionModel | None = None
    text_encoder: CLIPTextModel | None = None
    text_encoder_2: CLIPTextModelWithProjection | None = None

    def all_none(self) -> bool:
        """Return True if no submodel is set."""
        submodels = (self.unet, self.text_encoder, self.text_encoder_2)
        return all(submodel is None for submodel in submodels)
def load_model(
    logger: logging.Logger,
    model_name_or_path: str,
    model_type: PipelineVersionEnum,
    variant: str | None,
    dtype: torch.dtype,
) -> StableDiffusionModel:
    """Load the UNet and text encoder(s) of a SD/SDXL model for LoRA extraction.

    Args:
        logger: Logger for progress messages.
        model_name_or_path: One of:
            - a local diffusers-format directory, possibly containing only a subset of the `unet`, `text_encoder` and
              `text_encoder_2` subdirectories;
            - a single checkpoint file; or
            - a HF Hub model name.
        model_type: The pipeline version. Only used when `model_name_or_path` is not a directory.
        variant: Optional HF variant (e.g. 'fp16').
        dtype: The dtype to load the weights in.

    Returns:
        A StableDiffusionModel in which every submodel that could be loaded is set; the others are None.

    Raises:
        RuntimeError: If the loaded pipeline has an unexpected type, or if no submodels could be loaded.
    """
    sd_model = StableDiffusionModel()

    model_path = Path(model_name_or_path)
    if model_path.is_dir():
        # model_path is a directory, so we'll try to load the submodels of interest from its subdirectories.
        logger.info(f"'{model_name_or_path}' is a directory. Attempting to load submodels.")

        for submodel_name, submodel_class in [
            ("unet", UNet2DConditionModel),
            ("text_encoder", CLIPTextModel),
            ("text_encoder_2", CLIPTextModelWithProjection),
        ]:
            submodel_path: Path = model_path / submodel_name
            if not submodel_path.exists():
                logger.info(f"'{submodel_name}' not found in '{model_name_or_path}'. Skipping.")
                continue

            logger.info(f"Loading '{submodel_name}' from '{submodel_path}'.")
            submodel = from_pretrained_with_variant_fallback(
                logger=logger,
                model_class=submodel_class,
                model_name_or_path=submodel_path,
                torch_dtype=dtype,
                variant=variant,
                local_files_only=True,
            )
            setattr(sd_model, submodel_name, submodel)
    else:
        # model_name_or_path is not a directory, so it is either:
        # 1) a single checkpoint file
        # 2) a HF model name
        # Both can be loaded by calling load_pipeline.
        logger.info(
            f"'{model_name_or_path}' is not a directory. Attempting to load it as a single checkpoint file or a HF Hub "
            "model."
        )
        pipeline = load_pipeline(
            logger=logger,
            model_name_or_path=model_name_or_path,
            pipeline_version=model_type,
            torch_dtype=dtype,
            variant=variant,
        )

        if isinstance(pipeline, StableDiffusionPipeline):
            sd_model.unet = pipeline.unet
            sd_model.text_encoder = pipeline.text_encoder
        elif isinstance(pipeline, StableDiffusionXLPipeline):
            sd_model.unet = pipeline.unet
            sd_model.text_encoder = pipeline.text_encoder
            sd_model.text_encoder_2 = pipeline.text_encoder_2
        else:
            raise RuntimeError(f"Unexpected pipeline type: {type(pipeline)}.")

    if sd_model.all_none():
        raise RuntimeError(f"Failed to load any submodels from '{model_name_or_path}'.")

    return sd_model
def str_to_device(device_str: Literal["cuda", "cpu"]) -> torch.device:
    """Convert a device name ('cuda' or 'cpu') to a `torch.device`.

    Raises:
        ValueError: For any other device name.
    """
    if device_str not in ("cuda", "cpu"):
        raise ValueError(f"Unexpected device: {device_str}")
    return torch.device(device_str)
def state_dict_to_device(state_dict: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]:
    """Return a new dict with every tensor of `state_dict` moved to `device`. Key order is preserved."""
    moved: dict[str, torch.Tensor] = {}
    for name, tensor in state_dict.items():
        moved[name] = tensor.to(device=device)
    return moved
def extract_lora_from_submodel(
    logger: logging.Logger,
    model_orig: torch.nn.Module,
    model_tuned: torch.nn.Module,
    device: torch.device,
    out_dtype: torch.dtype,
    lora_target_modules: list[str],
    lora_rank: int,
    clamp_quantile: float = 0.99,
) -> peft.PeftModel:
    """Extract LoRA weights from the diff between model_orig and model_tuned.

    Returns model_orig wrapped in a PeftModel, with the extracted LoRA weights loaded into its LoRA layers.

    Args:
        logger: Logger for progress messages.
        model_orig: The original submodel (UNet or text encoder).
        model_tuned: The tuned version of the same submodel.
        device: Device used for the SVD computation. The results are moved back to the CPU.
        out_dtype: The dtype of the extracted LoRA weights.
        lora_target_modules: The module names/patterns to which LoRA layers are applied.
        lora_rank: The LoRA rank.
        clamp_quantile: Quantile used to clamp outliers in the extracted weights.

    Raises:
        RuntimeError: If the extracted LoRA weights do not match the LoRA layers of the PEFT-wrapped model_orig.
    """
    # Apply LoRA to the submodel (UNet or text encoder).
    # The only reason we do this is to get the module names for the weights that we'll extract. We don't actually use
    # the LoRA weights initialized here.
    lora_config = peft.LoraConfig(
        r=lora_rank,
        # We set the alpha to the rank, because we don't want any scaling to be applied to the LoRA weights that we
        # extract.
        lora_alpha=lora_rank,
        target_modules=lora_target_modules,
    )
    model_tuned = peft.get_peft_model(model_tuned, lora_config)
    model_orig = peft.get_peft_model(model_orig, lora_config)

    base_weights_tuned = get_patched_base_weights_from_peft_model(model_tuned)
    base_weights_orig = get_patched_base_weights_from_peft_model(model_orig)

    diffs = get_state_dict_diff(base_weights_tuned, base_weights_orig)

    # Drop our references to the tuned model and the base weights, which are no longer needed. NOTE: this only frees
    # memory if the caller holds no other references to the tuned model.
    del model_tuned, base_weights_tuned, base_weights_orig

    # Apply SVD (Singular Value Decomposition) to the diffs.
    # We use the requested device for this calculation, since it's slow, then we move the results back to the CPU.
    logger.info("Calculating LoRA weights with SVD.")
    diffs = state_dict_to_device(diffs, device)
    # TODO(ryand): Should we skip if the diffs are all zeros? This would happen if two models are identical. This could
    # happen if some submodels differ while others don't.
    lora_weights = extract_lora_from_diffs(
        diffs=diffs, rank=lora_rank, clamp_quantile=clamp_quantile, out_dtype=out_dtype
    )
    del diffs

    # Prepare state dict for LoRA.
    # The alpha value is set once globally in the PEFT model (via lora_config), so it is not set per module.
    lora_state_dict = {}
    for module_name, (lora_up, lora_down) in lora_weights.items():
        lora_state_dict[PEFT_BASE_LAYER_PREFIX + module_name + ".lora_A.default.weight"] = lora_down
        lora_state_dict[PEFT_BASE_LAYER_PREFIX + module_name + ".lora_B.default.weight"] = lora_up

    lora_state_dict = state_dict_to_device(lora_state_dict, torch.device("cpu"))

    # Load the state_dict into the LoRA model.
    # strict=False is required because lora_state_dict intentionally contains only the LoRA weights: every base weight
    # is reported as "missing". Unexpected keys, however, mean that some extracted weights would be silently dropped,
    # leaving those LoRA layers at their initial values. That is treated as an error.
    load_result = model_orig.load_state_dict(lora_state_dict, strict=False, assign=True)
    if load_result.unexpected_keys:
        raise RuntimeError(
            f"Failed to load {len(load_result.unexpected_keys)} extracted LoRA weight(s) into the PEFT model. "
            f"First unexpected key: '{load_result.unexpected_keys[0]}'."
        )
    return model_orig
def _check_save_path(save_to: str) -> Path:
    """Validate that `save_to` is a usable destination for the extracted LoRA and return it as a Path.

    Raises:
        ValueError: If `save_to` does not have a '.safetensors' extension.
        FileExistsError: If `save_to` already exists.
    """
    save_to_path = Path(save_to)
    if save_to_path.suffix != ".safetensors":
        raise ValueError(f"Destination file must have a '.safetensors' extension: '{save_to}'.")
    if save_to_path.exists():
        raise FileExistsError(f"Destination file already exists: '{save_to}'.")
    return save_to_path


@torch.no_grad()
def extract_lora(
    logger: logging.Logger,
    model_type: PipelineVersionEnum,
    orig_model_name_or_path: str,
    orig_model_variant: str | None,
    tuned_model_name_or_path: str,
    tuned_model_variant: str | None,
    save_to: str,
    load_precision: Literal["float32", "float16", "bfloat16"],
    save_precision: Literal["float32", "float16", "bfloat16"],
    device: Literal["cuda", "cpu"],
    lora_rank: int,
    clamp_quantile=0.99,
):
    """Extract a LoRA from the weight diff between an original and a tuned SD/SDXL model, and save it to a file.

    Each submodel (unet, text_encoder, text_encoder_2) that is present in both models is processed. Submodels missing
    from either model are skipped.

    Args:
        logger: Logger for progress messages.
        model_type: The pipeline version of both models.
        orig_model_name_or_path: Path or HF Hub name of the original model.
        orig_model_variant: Optional HF variant of the original model.
        tuned_model_name_or_path: Path or HF Hub name of the tuned model.
        tuned_model_variant: Optional HF variant of the tuned model.
        save_to: Destination '.safetensors' file. It must not already exist.
        load_precision: The dtype to load the models in.
        save_precision: The dtype of the saved LoRA weights.
        device: The device used for the SVD computations.
        lora_rank: The LoRA rank.
        clamp_quantile: Quantile used to clamp outliers in the extracted weights.

    Raises:
        ValueError: If `save_to` does not have a '.safetensors' extension.
        FileExistsError: If `save_to` already exists.
    """
    # Validate the destination before doing any expensive work: loading the models and running the SVDs can take a
    # long time.
    save_to_path = _check_save_path(save_to)

    load_dtype = get_dtype_from_str(load_precision)
    save_dtype = get_dtype_from_str(save_precision)
    device = str_to_device(device)

    orig_model = load_model(
        logger=logger,
        model_name_or_path=orig_model_name_or_path,
        model_type=model_type,
        dtype=load_dtype,
        variant=orig_model_variant,
    )
    tuned_model = load_model(
        logger=logger,
        model_name_or_path=tuned_model_name_or_path,
        model_type=model_type,
        dtype=load_dtype,
        variant=tuned_model_variant,
    )

    lora_models: dict[str, peft.PeftModel] = {}
    for submodel_name, submodel_orig, submodel_tuned, lora_target_modules in [
        ("unet", orig_model.unet, tuned_model.unet, UNET_TARGET_MODULES),
        ("text_encoder", orig_model.text_encoder, tuned_model.text_encoder, TEXT_ENCODER_TARGET_MODULES),
        ("text_encoder_2", orig_model.text_encoder_2, tuned_model.text_encoder_2, TEXT_ENCODER_TARGET_MODULES),
    ]:
        if submodel_orig is not None and submodel_tuned is not None:
            logger.info(f"Extracting LoRA weights for '{submodel_name}'.")
            lora_models[submodel_name] = extract_lora_from_submodel(
                logger=logger,
                model_orig=submodel_orig,
                model_tuned=submodel_tuned,
                device=device,
                out_dtype=save_dtype,
                lora_target_modules=lora_target_modules,
                lora_rank=lora_rank,
                clamp_quantile=clamp_quantile,
            )
        else:
            logger.info(f"Skipping '{submodel_name}'.")

    # Save the LoRA weights.
    # Check the destination again, in case it was created while the extraction was running.
    _check_save_path(save_to)
    save_to_path.parent.mkdir(parents=True, exist_ok=True)
    # NOTE(review): the SDXL kohya format is used regardless of `model_type`. For SD 1.x models, confirm that
    # consumers accept the resulting text encoder key names.
    save_sdxl_kohya_checkpoint(
        save_to_path,
        unet=lora_models.get("unet", None),
        text_encoder_1=lora_models.get("text_encoder", None),
        text_encoder_2=lora_models.get("text_encoder_2", None),
    )
    logger.info(f"Saved LoRA weights to: {save_to_path}")
def main():
    """Command-line entry point: extract a LoRA from the weight diff between an original and a tuned model."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-type",
        type=str,
        required=True,
        choices=["SD", "SDXL"],
        help="The type of the models to extract the LoRA from ['SD', 'SDXL'].",
    )
    # Shared description of the accepted model formats. Each fragment ends with a space because they are concatenated.
    model_format_help = (
        "The model must be in one of the following formats: "
        "1) a single checkpoint file (e.g. '.safetensors') containing all submodels, "
        "2) a model in diffusers format containing all submodels, "
        "or 3) a model in diffusers format containing a subset of the submodels (e.g. only a UNet). "
        "An HF variant can optionally be appended to the model name after a double-colon delimiter ('::'). "
    )
    parser.add_argument(
        "--model-orig",
        type=str,
        required=True,
        help="Path or HF Hub name of the original model. "
        + model_format_help
        + "E.g. '--model-orig runwayml/stable-diffusion-v1-5::fp16'",
    )
    parser.add_argument(
        "--model-tuned",
        type=str,
        required=True,
        help="Path or HF Hub name of the tuned model. "
        + model_format_help
        + "E.g. '--model-tuned path/to/tuned_model::fp16'",
    )
    parser.add_argument(
        "--save-to",
        type=str,
        required=True,
        help="Destination file path (must have a .safetensors extension).",
    )
    parser.add_argument(
        "--load-precision",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
        help="Model load precision.",
    )
    parser.add_argument(
        "--save-precision",
        type=str,
        default="float16",
        choices=["float32", "float16", "bfloat16"],
        help="Model save precision.",
    )
    parser.add_argument("--lora-rank", type=int, default=4, help="LoRA rank dimension.")
    parser.add_argument("--clamp-quantile", type=float, default=0.99, help="Quantile clamping value. (0-1)")
    parser.add_argument(
        "--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device to use. (cuda or cpu)"
    )

    args = parser.parse_args()

    # Reject invalid numeric arguments up front, before any model is loaded.
    if args.lora_rank < 1:
        parser.error("--lora-rank must be >= 1.")
    if not 0.0 <= args.clamp_quantile <= 1.0:
        parser.error("--clamp-quantile must be in the range [0, 1].")

    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    logger = logging.getLogger()

    orig_model_name_or_path, orig_model_variant = parse_model_arg(args.model_orig)
    tuned_model_name_or_path, tuned_model_variant = parse_model_arg(args.model_tuned)

    extract_lora(
        logger=logger,
        model_type=PipelineVersionEnum(args.model_type),
        orig_model_name_or_path=orig_model_name_or_path,
        orig_model_variant=orig_model_variant,
        tuned_model_name_or_path=tuned_model_name_or_path,
        tuned_model_variant=tuned_model_variant,
        save_to=args.save_to,
        load_precision=args.load_precision,
        save_precision=args.save_precision,
        device=args.device,
        lora_rank=args.lora_rank,
        clamp_quantile=args.clamp_quantile,
    )
# Run the LoRA extraction CLI when this module is executed as a script.
if __name__ == "__main__":
    main()
================================================
FILE: src/invoke_training/model_merge/scripts/merge_lora_into_model.py
================================================
import argparse # noqa: I001
import logging
from pathlib import Path
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
# fmt: off
# HACK(ryand): Import order matters, because invokeai contains circular imports.
from invokeai.backend.model_manager.taxonomy import BaseModelType
from invokeai.backend.patches.layer_patcher import LayerPatcher
from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import \
lora_model_from_sd_state_dict
from invokeai.backend.util.original_weights_storage import \
OriginalWeightsStorage
from safetensors.torch import load_file
# fmt: on
from invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str
from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline
from invoke_training.model_merge.utils.parse_model_arg import parse_model_arg
def to_invokeai_base_model_type(model_type: PipelineVersionEnum):
if model_type == PipelineVersionEnum.SD:
return BaseModelType.StableDiffusion1
elif model_type == PipelineVersionEnum.SDXL:
return BaseModelType.StableDiffusionXL
else:
raise ValueError(f"Unexpected model_type: {model_type}")
@torch.no_grad()
def merge_lora_into_sd_model(
    logger: logging.Logger,
    model_type: PipelineVersionEnum,
    base_model: str,
    base_model_variant: str | None,
    lora_models: list[tuple[str, float]],
    output: str,
    save_dtype: str,
):
    """Merge one or more LoRA models into a SD/SDXL base model and save the result in diffusers format.

    Args:
        logger: Logger for progress messages.
        model_type: The pipeline version of the base model.
        base_model: HF Hub name, local diffusers directory, or single checkpoint file of the base model.
        base_model_variant: Optional HF variant of the base model. It is also used as the variant of the saved model.
        lora_models: (path, weight) pairs. '.safetensors' files are loaded with safetensors; any other file is loaded
            with `torch.load`.
        output: Output directory. It must not already exist.
        save_dtype: Dtype name (e.g. 'float16'). The whole pipeline is cast to this dtype before the LoRAs are applied.

    Raises:
        FileExistsError: If `output` already exists.
        ValueError: If the loaded pipeline has an unexpected type.
    """
    output_path = Path(output)
    # Fail fast, before the expensive model loading and patching. (`mkdir` below would also raise, but only after all
    # of the work had been done.)
    if output_path.exists():
        raise FileExistsError(f"Output path already exists: '{output_path}'.")

    pipeline: StableDiffusionXLPipeline | StableDiffusionPipeline = load_pipeline(
        logger=logger, model_name_or_path=base_model, pipeline_version=model_type, variant=base_model_variant
    )
    save_dtype = get_dtype_from_str(save_dtype)

    logger.info(f"Loaded base model: '{base_model}'.")

    pipeline.to(save_dtype)

    models: list[torch.nn.Module] = []
    lora_prefixes: list[str] = []
    if isinstance(pipeline, StableDiffusionPipeline):
        models = [pipeline.unet, pipeline.text_encoder]
        lora_prefixes = ["lora_unet_", "lora_te_"]
    elif isinstance(pipeline, StableDiffusionXLPipeline):
        models = [pipeline.unet, pipeline.text_encoder, pipeline.text_encoder_2]
        lora_prefixes = ["lora_unet_", "lora_te1_", "lora_te2_"]
    else:
        raise ValueError(f"Unexpected pipeline type: {type(pipeline)}")

    # Although we are not unpatching, the patcher might require this. Initialize empty.
    original_weights = OriginalWeightsStorage()

    for lora_model_path, lora_model_weight in lora_models:
        # Load state dict from file
        lora_path = Path(lora_model_path)
        if lora_path.suffix == ".safetensors":
            state_dict = load_file(lora_path.absolute().as_posix(), device="cpu")
        else:
            # Assuming .ckpt, .pt, .bin etc. are torch checkpoints.
            # SECURITY: torch.load unpickles the file, which can execute arbitrary code when loading an untrusted
            # checkpoint. weights_only=True restricts unpickling to tensors and primitive containers, which is all a
            # LoRA state dict needs.
            state_dict = torch.load(lora_path, map_location="cpu", weights_only=True)

        # Convert state dict to ModelPatchRaw
        lora_model = lora_model_from_sd_state_dict(state_dict=state_dict)

        # Apply the patch using LayerPatcher
        for model, lora_prefix in zip(models, lora_prefixes, strict=True):
            LayerPatcher.apply_smart_model_patch(
                model=model,
                prefix=lora_prefix,
                patch=lora_model,
                patch_weight=lora_model_weight,
                original_weights=original_weights,  # Pass storage, even if unused for merging
                original_modules={},  # Pass empty dict, not needed for direct patching/merging
                dtype=model.dtype,  # Use the model's dtype
                # Force direct patching since we are merging into the main weights
                force_direct_patching=True,
                force_sidecar_patching=False,
            )
        logger.info(f"Applied LoRA model '{lora_model_path}' with weight {lora_model_weight}.")

    output_path.mkdir(parents=True)

    # TODO(ryand): Should we keep the base model variant? This is clearly a flawed assumption.
    pipeline.save_pretrained(output_path, variant=base_model_variant)
    logger.info(f"Saved merged model to '{output_path}'.")
def parse_lora_model_arg(lora_model_arg: str) -> tuple[str, float]:
"""Parse a --lora-model argument into a tuple of the model path and weight."""
parts = lora_model_arg.split("::")
if len(parts) == 1:
return parts[0], 1.0
elif len(parts) == 2:
return parts[0], float(parts[1])
else:
raise ValueError(f"Unexpected format for --lora-model arg: '{lora_model_arg}'.")
def main():
    """Command-line entry point: merge one or more LoRA models into a base SD/SDXL model and save the result."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-type",
        type=str,
        choices=["SD", "SDXL"],
        # Required: PipelineVersionEnum(None) below would otherwise fail with an unhelpful error.
        required=True,
        help="The type of the models to merge ['SD', 'SDXL'].",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        help="The base model to merge LoRAs into. The model can be either 1) an HF hub name, 2) a path to a local "
        "diffusers model directory, or 3) a path to a single checkpoint file. An HF variant can optionally be appended "
        "to the model name after a double-colon delimiter ('::'). "
        "E.g. '--base-model runwayml/stable-diffusion-v1-5::fp16'",
        required=True,
    )
    parser.add_argument(
        "--lora-models",
        type=str,
        nargs="+",
        help="The path(s) to one or more LoRA models to merge into the base model. Model weights can be appended to "
        "the path, separated by a double colon ('::'). The weight is optional and defaults to 1.0. E.g. "
        "'--lora-models path/to/lora_model_1.safetensors::0.5 path/to/lora_model_2.safetensors'.",
        required=True,
    )
    parser.add_argument(
        "--output",
        type=str,
        # Required: without it, the merge would run to completion and only then fail on Path(None).
        required=True,
        help="The path to an output directory where the merged model will be saved (in diffusers format).",
    )
    parser.add_argument(
        "--save-dtype",
        type=str,
        default="float16",
        choices=["float32", "float16", "bfloat16"],
        help="The dtype to save the model as.",
    )

    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()

    base_model, base_model_variant = parse_model_arg(args.base_model)
    lora_models = [parse_lora_model_arg(arg) for arg in args.lora_models]

    # Log the parsed arguments
    logger.info(f"Model type: {args.model_type}")
    logger.info(f"Base model: {base_model}")
    logger.info(f"Base model variant: {base_model_variant}")
    logger.info(f"Output directory: {args.output}")
    logger.info(f"Save dtype: {args.save_dtype}")
    lora_models_str = " - " + "\n - ".join([f"{model} ({weight})" for model, weight in lora_models])
    logger.info(f"LoRA models:\n{lora_models_str}")

    merge_lora_into_sd_model(
        logger=logger,
        model_type=PipelineVersionEnum(args.model_type),
        base_model=base_model,
        base_model_variant=base_model_variant,
        lora_models=lora_models,
        output=args.output,
        save_dtype=args.save_dtype,
    )


if __name__ == "__main__":
    main()
================================================
FILE: src/invoke_training/model_merge/scripts/merge_models.py
================================================
import argparse
import logging
from dataclasses import dataclass
from pathlib import Path
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str
from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline
from invoke_training.model_merge.merge_models import merge_models
from invoke_training.model_merge.utils.parse_model_arg import parse_model_arg
@dataclass
class MergeModel:
    """A single model to be merged, as parsed from the command line by `parse_model_args(...)`.

    Attributes:
        model_name_or_path: An HF hub name, a path to a local diffusers model directory, or a path to a single
            checkpoint file.
        variant: The optional HF variant that was appended after '::' (e.g. 'fp16'), or None.
        weight: The weight of this model in the merge. How it is interpreted depends on the merge script that uses it.
    """

    model_name_or_path: str
    variant: str | None
    weight: float
def _merge_submodel_into_first(
    logger: logging.Logger,
    pipelines: list[StableDiffusionPipeline] | list[StableDiffusionXLPipeline],
    submodel_name: str,
    weights: list[float],
    method: str,
):
    """Merge the `submodel_name` submodel of every pipeline and load the result into the first pipeline, in-place."""
    submodels: list[torch.nn.Module] = [getattr(pipeline, submodel_name) for pipeline in pipelines]
    submodel_state_dicts: list[dict[str, torch.Tensor]] = [submodel.state_dict() for submodel in submodels]
    logger.info(f"Merging {submodel_name} state_dicts...")
    merged_state_dict = merge_models(state_dicts=submodel_state_dicts, weights=weights, merge_method=method)
    # Merge the merged_state_dict back into the first pipeline to keep memory utilization low.
    submodels[0].load_state_dict(merged_state_dict, assign=True)
    logger.info(f"Merged {submodel_name} state_dicts.")


def run_merge_models(
    logger: logging.Logger,
    model_type: PipelineVersionEnum,
    models: list[MergeModel],
    method: str,
    out_dir: str,
    dtype: torch.dtype,
):
    """Merge several pipelines of the same type into one and save the result in diffusers format.

    The 'unet' and text encoder submodels of all pipelines are merged with `merge_models(...)`. All other pipeline
    components are taken from the first entry in `models`.

    Args:
        logger: Logger for progress messages.
        model_type: The pipeline type shared by all `models` (SD or SDXL).
        models: The models to merge, each with an optional HF variant and a merge weight.
        method: The merge method passed to `merge_models(...)` (e.g. 'LERP' or 'SLERP').
        out_dir: The output directory. It must not already exist.
        dtype: The torch dtype used to load the models (and therefore for the merge and the saved output).

    Raises:
        FileExistsError: If `out_dir` already exists.
        ValueError: If `model_type` is not supported.
    """
    # Create the output directory. exist_ok=False makes this fail if the directory already exists; doing it first means
    # that failure happens before any expensive model loading.
    out_dir_path = Path(out_dir)
    out_dir_path.mkdir(parents=True, exist_ok=False)

    # Select the submodels to merge. (Checked before loading so that an unsupported type fails fast.)
    if model_type == PipelineVersionEnum.SDXL:
        submodel_names = ["unet", "text_encoder", "text_encoder_2"]
    elif model_type == PipelineVersionEnum.SD:
        submodel_names = ["unet", "text_encoder"]
    else:
        raise ValueError(f"Unexpected model type: {model_type}")

    # Load the models. (A comprehension is used so that no loop variable keeps a reference to the last pipeline.)
    loaded_models: list[StableDiffusionPipeline] | list[StableDiffusionXLPipeline] = [
        load_pipeline(
            logger=logger,
            model_name_or_path=model.model_name_or_path,
            pipeline_version=model_type,
            torch_dtype=dtype,
            variant=model.variant,
        )
        for model in models
    ]

    # Merge the models, one submodel at a time. The per-submodel work is done in a helper so that its intermediate
    # state_dicts are released as soon as each submodel is merged.
    weights = [model.weight for model in models]
    for submodel_name in submodel_names:
        _merge_submodel_into_first(
            logger=logger,
            pipelines=loaded_models,
            submodel_name=submodel_name,
            weights=weights,
            method=method,
        )

    # Only the first pipeline (which now holds the merged weights) is needed from here on. Release the other pipelines
    # before saving, since save_pretrained(...) has been observed to cause a large spike in memory usage.
    merged_pipeline = loaded_models[0]
    del loaded_models

    # Save the merged model.
    logger.info("Saving result...")
    merged_pipeline.save_pretrained(out_dir_path)
    logger.info(f"Saved merged model to '{out_dir_path}'.")
def parse_model_args(models: list[str], weights: list[float] | list[str]) -> list[MergeModel]:
    """Parse a list of model arguments and a list of weights into a list of `MergeModel`s.

    Args:
        models: Model arguments, each of the form '<model>' or '<model>::<variant>' (see `parse_model_arg`).
        weights: One weight per model. Each value is converted with `float(...)`, so both floats and numeric strings
            are accepted.

    Returns:
        One `MergeModel` per entry in `models`, in the same order.

    Raises:
        ValueError: If `models` and `weights` have different lengths, or if a model argument or weight is malformed.
    """
    # Check explicitly (rather than relying on zip(strict=True)) to give a clear error message.
    if len(models) != len(weights):
        raise ValueError(f"The number of models ({len(models)}) must match the number of weights ({len(weights)}).")

    merge_model_list: list[MergeModel] = []
    for model, weight in zip(models, weights, strict=True):
        parsed_model, parsed_variant = parse_model_arg(model)
        merge_model_list.append(
            MergeModel(model_name_or_path=parsed_model, variant=parsed_variant, weight=float(weight))
        )
    return merge_model_list
def main():
    """Command-line entry point: merge two or more SD/SDXL models into one and save it in diffusers format."""
    parser = argparse.ArgumentParser()
    # TODO(ryand): Auto-detect the model-type.
    parser.add_argument(
        "--model-type",
        type=str,
        choices=["SD", "SDXL"],
        # Required: PipelineVersionEnum(None) below would otherwise fail with an unhelpful error.
        required=True,
        help="The type of the models to merge ['SD', 'SDXL'].",
    )
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        required=True,
        help="Two or more models to merge. Each model can be either 1) an HF hub name, 2) a path to a local diffusers "
        "model directory, or 3) a path to a single checkpoint file. An HF variant can optionally be appended to the "
        "model name after a double-colon delimiter ('::'). "
        "E.g. '--models runwayml/stable-diffusion-v1-5::fp16 path/to/local/model.safetensors'",
    )
    parser.add_argument(
        "--weights",
        nargs="+",
        type=float,
        required=True,
        help="The weights for each model. The weights will be normalized to sum to 1. "
        "For example, to merge models with equal weights: '--weights 1.0 1.0'. "
        "To weight the first model more heavily: '--weights 0.75 0.25'.",
    )
    parser.add_argument(
        "--method",
        type=str,
        default="LERP",
        choices=["LERP", "SLERP"],
        help="The merge method to use. Options: 'LERP' (linear interpolation) or 'SLERP' (spherical linear "
        "interpolation).",
    )
    parser.add_argument(
        "--out-dir",
        type=str,
        required=True,
        help="The output directory where the merged model will be written (in diffusers format).",
    )
    parser.add_argument(
        "--dtype",
        help="The torch dtype that will be used for all calculations and for the output model.",
        type=str,
        default="float16",
        choices=["float32", "float16", "bfloat16"],
    )

    args = parser.parse_args()

    # Validate here, so that a mismatch is reported as a normal CLI usage error.
    if len(args.models) != len(args.weights):
        parser.error(
            f"The number of --weights ({len(args.weights)}) must match the number of --models ({len(args.models)})."
        )

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    merge_model_list = parse_model_args(args.models, args.weights)
    run_merge_models(
        logger=logger,
        model_type=PipelineVersionEnum(args.model_type),
        models=merge_model_list,
        method=args.method,
        out_dir=args.out_dir,
        dtype=get_dtype_from_str(args.dtype),
    )


if __name__ == "__main__":
    main()
================================================
FILE: src/invoke_training/model_merge/scripts/merge_task_models_to_base_model.py
================================================
import argparse
import logging
from pathlib import Path
import torch
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str
from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline
from invoke_training.model_merge.merge_tasks_to_base import merge_tasks_to_base_model
from invoke_training.model_merge.scripts.merge_models import MergeModel, parse_model_args
def _merge_submodel_into_base(
    logger: logging.Logger,
    base_pipeline: StableDiffusionPipeline | StableDiffusionXLPipeline,
    task_pipelines: list[StableDiffusionPipeline] | list[StableDiffusionXLPipeline],
    submodel_name: str,
    task_weights: list[float],
    density: float,
    method: str,
):
    """Merge the `submodel_name` submodel of every task pipeline into the base pipeline's submodel, in-place."""
    base_submodel: torch.nn.Module = getattr(base_pipeline, submodel_name)
    base_submodel_state_dict = base_submodel.state_dict()
    task_submodel_state_dicts = [
        getattr(task_pipeline, submodel_name).state_dict() for task_pipeline in task_pipelines
    ]
    logger.info(f"Merging {submodel_name} state_dicts...")
    merged_state_dict = merge_tasks_to_base_model(
        base_state_dict=base_submodel_state_dict,
        task_state_dicts=task_submodel_state_dicts,
        task_weights=task_weights,
        density=density,
        merge_method=method,
    )
    # Merge the merged_state_dict back into the base model pipeline to keep memory utilization low.
    base_submodel.load_state_dict(merged_state_dict, assign=True)
    logger.info(f"Merged {submodel_name} state_dicts.")


def run_merge_models(
    logger: logging.Logger,
    model_type: PipelineVersionEnum,
    base_model: MergeModel,
    task_models: list[MergeModel],
    method: str,
    density: float,
    out_dir: str,
    dtype: torch.dtype,
):
    """Merge one or more task-specific models into a base model and save the result in diffusers format.

    The 'unet' and text encoder submodels are merged with `merge_tasks_to_base_model(...)`. All other pipeline
    components are taken from the base model.

    Args:
        logger: Logger for progress messages.
        model_type: The pipeline type shared by all models (SD or SDXL).
        base_model: The base model. Its `weight` is not used.
        task_models: The task-specific models, each with an optional HF variant and a task weight.
        method: The merge method passed to `merge_tasks_to_base_model(...)` (e.g. 'TIES', 'DARE_LINEAR', 'DARE_TIES').
        density: The density passed to `merge_tasks_to_base_model(...)`.
        out_dir: The output directory. It must not already exist.
        dtype: The torch dtype used to load the models (and therefore for the merge and the saved output).

    Raises:
        FileExistsError: If `out_dir` already exists.
        ValueError: If `model_type` is not supported.
    """
    # Create the output directory. exist_ok=False makes this fail if the directory already exists; doing it first means
    # that failure happens before any expensive model loading.
    out_dir_path = Path(out_dir)
    out_dir_path.mkdir(parents=True, exist_ok=False)

    # Select the submodels to merge. (Checked before loading so that an unsupported type fails fast.)
    if model_type == PipelineVersionEnum.SDXL:
        submodel_names = ["unet", "text_encoder", "text_encoder_2"]
    elif model_type == PipelineVersionEnum.SD:
        submodel_names = ["unet", "text_encoder"]
    else:
        raise ValueError(f"Unexpected model type: {model_type}")

    # Load the base model.
    loaded_base_model = load_pipeline(
        logger=logger,
        model_name_or_path=base_model.model_name_or_path,
        pipeline_version=model_type,
        torch_dtype=dtype,
        variant=base_model.variant,
    )

    # Load the task models. (A comprehension is used so that no loop variable keeps a reference to the last task model
    # after it is deleted below.)
    loaded_task_models: list[StableDiffusionPipeline] | list[StableDiffusionXLPipeline] = [
        load_pipeline(
            logger=logger,
            model_name_or_path=task_model.model_name_or_path,
            pipeline_version=model_type,
            torch_dtype=dtype,
            variant=task_model.variant,
        )
        for task_model in task_models
    ]

    # Merge the models. The per-submodel work is done in a helper so that no task submodel or task state_dict outlives
    # its merge step.
    task_model_weights = [task_model.weight for task_model in task_models]
    for submodel_name in submodel_names:
        _merge_submodel_into_base(
            logger=logger,
            base_pipeline=loaded_base_model,
            task_pipelines=loaded_task_models,
            submodel_name=submodel_name,
            task_weights=task_model_weights,
            density=density,
            method=method,
        )

    # Delete the task models to free up memory.
    # At the time of the writing, the save_pretrained(...) function below caused a large spike in memory usage. We free
    # the task models to increase its likelihood of success. This is the last remaining reference to them.
    del loaded_task_models

    # Save the merged model.
    logger.info("Saving result...")
    loaded_base_model.save_pretrained(out_dir_path)
    logger.info(f"Saved merged model to '{out_dir_path}'.")
def main():
    """Command-line entry point: merge task-specific SD/SDXL models into a base model and save the result."""
    parser = argparse.ArgumentParser()
    # TODO(ryand): Auto-detect the base-model-type.
    parser.add_argument(
        "--model-type",
        type=str,
        choices=["SD", "SDXL"],
        # Required: PipelineVersionEnum(None) below would otherwise fail with an unhelpful error.
        required=True,
        help="The type of the models to merge ['SD', 'SDXL'].",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        # Required: parse_model_args(...) cannot handle a missing (None) base model.
        required=True,
        help="The base model to merge task-specific models into. Can be either 1) an HF hub name, 2) a path to a local "
        "diffusers model directory, or 3) a path to a single checkpoint file. An HF variant can optionally be appended "
        "to the model name after a double-colon delimiter ('::'). "
        "E.g. '--base-model runwayml/stable-diffusion-v1-5::fp16'.",
    )
    parser.add_argument(
        "--task-models",
        nargs="+",
        type=str,
        required=True,
        help="One or more task-specific models to merge into the base model. Each model can be either 1) an HF hub "
        "name, 2) a path to a local diffusers model directory, or 3) a path to a single checkpoint file. An HF variant "
        "can optionally be appended to the model name after a double-colon delimiter ('::'). "
        "E.g. '--task-models runwayml/stable-diffusion-v1-5::fp16 path/to/local/model.safetensors'",
    )
    parser.add_argument(
        "--task-weights",
        nargs="+",
        type=float,
        required=True,
        help="The weights for each task model. The weights are multipliers applied to the diff between each task model "
        "and the base model. As a starting point, it is recommended to use a weight of 1.0 for all task models, e.g. "
        "'--task-weights 1.0 1.0'. The weights can then be tuned from there, e.g. '--task-weights 1.0 1.3'.",
    )
    parser.add_argument(
        "--method",
        type=str,
        default="TIES",
        choices=["TIES", "DARE_LINEAR", "DARE_TIES"],
        help="The merge method to use. Options: ['TIES', 'DARE_LINEAR', 'DARE_TIES'].",
    )
    parser.add_argument(
        "--density",
        type=float,
        default=0.2,
        help="The fraction of values to preserve in the prune/trim step of DARE/TIES methods. Should be in the range "
        "[0, 1].",
    )
    parser.add_argument(
        "--out-dir",
        type=str,
        required=True,
        help="The output directory where the merged model will be written (in diffusers format).",
    )
    parser.add_argument(
        "--dtype",
        help="The torch dtype that will be used for all calculations and for the output model.",
        type=str,
        default="float16",
        choices=["float32", "float16", "bfloat16"],
    )

    args = parser.parse_args()

    # Validate here, so that bad values are reported as normal CLI usage errors before any models are loaded.
    if len(args.task_models) != len(args.task_weights):
        parser.error(
            f"The number of --task-weights ({len(args.task_weights)}) must match the number of --task-models "
            f"({len(args.task_models)})."
        )
    # Written as a negated range check so that NaN is also rejected.
    if not 0.0 <= args.density <= 1.0:
        parser.error(f"--density must be in the range [0, 1], but got {args.density}.")

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()

    base_model = parse_model_args([args.base_model], [1.0])[0]
    task_models = parse_model_args(args.task_models, args.task_weights)
    run_merge_models(
        logger=logger,
        model_type=PipelineVersionEnum(args.model_type),
        base_model=base_model,
        task_models=task_models,
        method=args.method,
        density=args.density,
        out_dir=args.out_dir,
        dtype=get_dtype_from_str(args.dtype),
    )


if __name__ == "__main__":
    main()
================================================
FILE: src/invoke_training/model_merge/utils/__init__.py
================================================
================================================
FILE: src/invoke_training/model_merge/utils/normalize_weights.py
================================================
def normalize_weights(weights: list[float]) -> list[float]:
    """Scale `weights` so that they sum to 1.

    Args:
        weights: The weights to normalize. May be empty.

    Returns:
        A new list with each weight divided by the sum of all weights. An empty input returns an empty list.

    Raises:
        ZeroDivisionError: If `weights` is non-empty and sums to zero (e.g. all zeros, or [1.0, -1.0]).
    """
    total = sum(weights)
    # Keep the same exception type as a plain division by zero, but make the cause obvious.
    if weights and total == 0:
        raise ZeroDivisionError(f"Cannot normalize weights that sum to zero: {weights}.")
    return [weight / total for weight in weights]
================================================
FILE: src/invoke_training/model_merge/utils/parse_model_arg.py
================================================
def parse_model_arg(model: str, delimiter: str = "::") -> tuple[str, str | None]:
"""Parse a model argument into a model and a variant."""
parts = model.split(delimiter)
if len(parts) == 1:
return parts[0], None
elif len(parts) == 2:
return parts[0], parts[1]
else:
raise ValueError(f"Unexpected format for --models arg: '{model}'.")
================================================
FILE: src/invoke_training/pipelines/__init__.py
================================================
================================================
FILE: src/invoke_training/pipelines/_experimental/sd_dpo_lora/config.py
================================================
from typing import Annotated, Literal, Union
from pydantic import Field, model_validator
from invoke_training.config.base_pipeline_config import BasePipelineConfig
from invoke_training.config.config_base_model import ConfigBaseModel
from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig
class HFHubImagePairPreferenceDatasetConfig(ConfigBaseModel):
    """Config for an image-pair preference dataset presumably loaded from the Hugging Face Hub (inferred from the name).

    NOTE(review): This is an unfinished placeholder (see the TODO below); it only defines the `type` discriminator used
    by `ImagePairPreferenceSDDataLoaderConfig.dataset`. Confirm that the data loader supports it before use.
    """

    type: Literal["HF_HUB_IMAGE_PAIR_PREFERENCE_DATASET"] = "HF_HUB_IMAGE_PAIR_PREFERENCE_DATASET"
    # TODO(ryand): Fill this out.
class ImagePairPreferenceDatasetConfig(ConfigBaseModel):
    """Config for an image-pair preference dataset that is loaded from a directory (`dataset_dir`)."""

    # Discriminator value used to select this config in `ImagePairPreferenceSDDataLoaderConfig.dataset`.
    type: Literal["IMAGE_PAIR_PREFERENCE_DATASET"] = "IMAGE_PAIR_PREFERENCE_DATASET"

    dataset_dir: str
    """The directory to load the dataset from."""
class ImagePairPreferenceSDDataLoaderConfig(ConfigBaseModel):
    """Data loader config for image-pair preference datasets, used by the SD Direct Preference Optimization pipeline."""

    type: Literal["IMAGE_PAIR_PREFERENCE_SD_DATA_LOADER"] = "IMAGE_PAIR_PREFERENCE_SD_DATA_LOADER"

    dataset: Annotated[
        Union[HFHubImagePairPreferenceDatasetConfig, ImagePairPreferenceDatasetConfig], Field(discriminator="type")
    ]
    """The dataset to load image pairs from. The dataset config variant is selected by its `type` field."""

    resolution: int | tuple[int, int] = 512
    """The resolution for input images. Either a scalar integer representing the square resolution height and width, or
    a (height, width) tuple. All of the images in the dataset will be resized to this resolution.
    """

    center_crop: bool = True
    """If True, input images will be center-cropped to the target resolution.
    If False, input images will be randomly cropped to the target resolution.
    """

    random_flip: bool = False
    """Whether random flip augmentations should be applied to input images.
    """

    dataloader_num_workers: int = 0
    """Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
    """
class SdDirectPreferenceOptimizationLoraConfig(BasePipelineConfig):
    """Config for the experimental Stable Diffusion Direct Preference Optimization (DPO) LoRA training pipeline.

    See 'Diffusion Model Alignment Using Direct Preference Optimization' (https://arxiv.org/pdf/2311.12908.pdf).
    """

    type: Literal["SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA"] = "SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA"

    model: str = "runwayml/stable-diffusion-v1-5"
    """Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint
    file. (E.g. 'runwayml/stable-diffusion-v1-5', '/path/to/realisticVisionV51_v51VAE.safetensors', etc.)
    """

    hf_variant: str | None = "fp16"
    """The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.
    """

    # Note: Pydantic handles mutable default values well:
    # https://docs.pydantic.dev/latest/concepts/models/#fields-with-non-hashable-default-values
    base_embeddings: dict[str, str] = {}
    """A mapping of embedding tokens to trained embedding file paths. These embeddings will be applied to the base model
    before training.

    Example:
    ```
    base_embeddings = {
        "bruce_the_gnome": "/path/to/bruce_the_gnome.safetensors",
    }
    ```

    Make sure that the embedding tokens are present in the dataset captions (this pipeline's data loader config does
    not provide a `caption_prefix` option for adding them automatically).

    Note that the embeddings themselves are not fine-tuned further, but they will impact the LoRA model training if they
    are referenced in the dataset captions. The list of embeddings provided here should be the same list used at
    generation time with the resultant LoRA model.
    """

    lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"
    """The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`."""

    train_unet: bool = True
    """Whether to add LoRA layers to the UNet model and train it.
    """

    train_text_encoder: bool = True
    """Whether to add LoRA layers to the text encoder and train it.
    """

    optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()
    """The optimizer configuration.
    """

    text_encoder_learning_rate: float | None = None
    """The learning rate to use for the text encoder model. If set, this overrides the optimizer's default learning
    rate. Set to null or 0 to use the optimizer's default learning rate.
    """

    unet_learning_rate: float | None = None
    """The learning rate to use for the UNet model. If set, this overrides the optimizer's default learning rate.
    Set to null or 0 to use the optimizer's default learning rate.
    """

    lr_scheduler: Literal[
        "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"
    ] = "constant"
    """The learning rate scheduler to use.
    """

    lr_warmup_steps: int = 0
    """The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.
    See lr_scheduler.
    """

    # NOTE(review): train_forward_dpo(...) does not read min_snr_gamma; confirm whether it has any effect for this
    # pipeline.
    min_snr_gamma: float | None = 5.0
    """Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy
    improves the speed of training convergence by adjusting the weight of each sample.

    `min_snr_gamma` acts like an upper bound on the weight of samples with low noise levels.

    If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`.
    """

    lora_rank_dim: int = 4
    """The rank dimension to use for the LoRA layers. Increasing the rank dimension increases the model's expressivity,
    but also increases the size of the generated LoRA model.
    """

    cache_text_encoder_outputs: bool = False
    """If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and
    the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the
    text encoders in VRAM), and speeds up training (don't have to run the text encoders for each training example).
    This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being
    applied.
    """

    cache_vae_outputs: bool = False
    """If True, the VAE will be applied to all of the images in the dataset before starting training and the results
    will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and
    speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all
    non-deterministic image augmentations are disabled (i.e. `data_loader.center_crop=True`,
    `data_loader.random_flip=False`).
    """

    enable_cpu_offload_during_validation: bool = False
    """If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation
    images. This reduces VRAM requirements at the cost of slower generation of validation images.
    """

    gradient_accumulation_steps: int = 1
    """The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face
    Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.
    """

    weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16"
    """All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and
    result in faster training, but are more prone to issues with numerical stability.

    Recommendations:

    - `"float32"`: Use this mode if you have plenty of VRAM available.
    - `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
    - `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.

    See also [`mixed_precision`][invoke_training.pipelines._experimental.sd_dpo_lora.config.SdDirectPreferenceOptimizationLoraConfig.mixed_precision].
    """  # noqa: E501

    mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
    """The mixed precision mode to use.

    If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and
    trainable parameters are kept in float32 precision to avoid issues with numerical stability.

    This value is passed to Hugging Face Accelerate. See
    [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)
    for more details.
    """  # noqa: E501

    xformers: bool = False
    """If true, use xformers for more efficient attention blocks.
    """

    gradient_checkpointing: bool = False
    """Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling
    gradient checkpointing slows down training by ~20%.
    """

    max_checkpoints: int | None = None
    """The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this
    limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.
    """

    prediction_type: Literal["epsilon", "v_prediction"] | None = None
    """The prediction_type that will be used for training. Choose between 'epsilon' or 'v_prediction' or leave 'None'.
    If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used.
    """

    max_grad_norm: float | None = None
    """Max gradient norm for clipping. Set to null or 0 for no clipping.
    """

    validation_prompts: list[str] = []
    """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.
    See also 'validate_every_n_epochs'.
    """

    negative_validation_prompts: list[str] | None = None
    """A list of negative prompts that will be applied when generating validation images. If set, this list should have
    the same length as 'validation_prompts'.
    """

    num_validation_images_per_prompt: int = 4
    """The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can
    become quite slow if this number is too large.
    """

    train_batch_size: int = 4
    """The training batch size.
    """

    data_loader: ImagePairPreferenceSDDataLoaderConfig
    """The data loader configuration (image-pair preference dataset, resolution, and image augmentations).
    """

    initial_lora: str | None = None
    """The LoRA checkpoint directory to initialize the LoRA weights from.

    If set, the following configuration parameters are ignored:

    - `train_unet`: The UNet will be trained if it is present in `initial_lora`.
    - `train_text_encoder`: The text encoder will be trained if it is present in `initial_lora`.
    - `lora_rank_dim`: The LoRA rank dimension from `initial_lora` will be used.

    Currently only LoRA checkpoints in the internal `invoke-training` PEFT format are supported (i.e. checkpoints
    generated by an `invoke-training` training pipeline).
    """

    beta: float = 5000.0
    """The beta parameter, as defined in (https://arxiv.org/pdf/2311.12908.pdf). Larger beta values increase the
    KL-Divergence penalty, discouraging divergence from the reference model weights.

    Typical values for `beta` are in the range [1000.0, 10000.0].
    """

    @model_validator(mode="after")
    def check_validation_prompts(self):
        """Check that `negative_validation_prompts`, if set, has the same length as `validation_prompts`."""
        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(
            self.validation_prompts
        ):
            raise ValueError(
                f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of "
                f"negative_validation_prompts ({len(self.negative_validation_prompts)})."
            )
        return self
================================================
FILE: src/invoke_training/pipelines/_experimental/sd_dpo_lora/train.py
================================================
import copy
import itertools
import json
import logging
import math
import os
import tempfile
import time
from pathlib import Path
from typing import Literal
import peft
import torch
import torch.utils.data
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from invoke_training._shared.accelerator.accelerator_utils import (
get_dtype_from_str,
initialize_accelerator,
initialize_logging,
)
from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker
from invoke_training._shared.data.data_loaders.image_pair_preference_sd_dataloader import (
build_image_pair_preference_sd_dataloader,
)
from invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer
from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (
TEXT_ENCODER_TARGET_MODULES,
UNET_TARGET_MODULES,
load_sd_peft_checkpoint,
save_sd_kohya_checkpoint,
save_sd_peft_checkpoint,
)
from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sd
from invoke_training._shared.stable_diffusion.tokenize_captions import tokenize_captions
from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sd
from invoke_training._shared.utils.import_xformers import import_xformers
from invoke_training.pipelines._experimental.sd_dpo_lora.config import SdDirectPreferenceOptimizationLoraConfig
from invoke_training.pipelines.callbacks import PipelineCallbacks
from invoke_training.pipelines.stable_diffusion.lora.train import cache_text_encoder_outputs
def _save_sd_lora_checkpoint(
epoch: int,
step: int,
unet: peft.PeftModel | None,
text_encoder: peft.PeftModel | None,
logger: logging.Logger,
checkpoint_tracker: CheckpointTracker,
lora_checkpoint_format: Literal["invoke_peft", "kohya"],
):
# Prune checkpoints and get new checkpoint path.
num_pruned = checkpoint_tracker.prune(1)
if num_pruned > 0:
logger.info(f"Pruned {num_pruned} checkpoint(s).")
save_path = checkpoint_tracker.get_path(epoch=epoch, step=step)
if lora_checkpoint_format == "invoke_peft":
save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
elif lora_checkpoint_format == "kohya":
save_sd_kohya_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder)
else:
raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")
def train_forward_dpo(  # noqa: C901
    config: SdDirectPreferenceOptimizationLoraConfig,
    data_batch: dict,
    vae: AutoencoderKL,
    noise_scheduler: DDPMScheduler,
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    ref_text_encoder: CLIPTextModel,
    ref_unet: UNet2DConditionModel,
    weight_dtype: torch.dtype,
) -> torch.Tensor:
    """Run the forward training pass for a single data_batch.

    This forward pass is based on 'Diffusion Model Alignment Using Direct Preference Optimization'
    (https://arxiv.org/pdf/2311.12908.pdf). See the "Pseudocode for Training Objective" Appendix section for a helpful
    reference.

    Args:
        config: The pipeline config. `prediction_type` and `beta` are read here.
        data_batch: A batch of image pairs. Keys read here:
            - 'image_0', 'image_1', 'prefer_0', 'prefer_1';
            - 'caption', unless a cached 'text_encoder_output' is present;
            - an optional cached 'vae_output', used instead of VAE encoding.
        vae: Encodes the images to latents (unless 'vae_output' is present in `data_batch`).
        noise_scheduler: Used to sample timesteps, add noise, and (for v-prediction) compute the target.
        tokenizer: Tokenizer for the captions.
        text_encoder: The text encoder whose outputs condition `unet`.
        unet: The UNet being trained.
        ref_text_encoder: The reference text encoder. It is run without gradient tracking.
        ref_unet: The reference UNet. It is run without gradient tracking.
        weight_dtype: The dtype that images and text embeddings are cast to before being passed to the models.

    Returns:
        torch.Tensor: Loss

    Raises:
        ValueError: If an image pair does not have exactly one preferred image, or the prediction type is unknown.
        NotImplementedError: If `data_batch` contains a 'loss_weight'.
    """
    # Checked up-front so that we don't run expensive forward passes before failing.
    if "loss_weight" in data_batch:
        raise NotImplementedError("loss_weight is not yet supported.")

    batch_size = data_batch["image_0"].shape[0]

    # Concatenate image_0 and image_1 images into a single image batch.
    images = torch.concat((data_batch["image_0"], data_batch["image_1"]))

    # Re-order images so that the 'images' batch contains all winner images followed by all loser images.
    w_indices = []
    l_indices = []
    prefer_0 = data_batch["prefer_0"]
    prefer_1 = data_batch["prefer_1"]
    for i in range(batch_size):
        if prefer_0[i] and not prefer_1[i]:
            w_indices.append(i)
            l_indices.append(i + batch_size)
        elif not prefer_0[i] and prefer_1[i]:
            w_indices.append(i + batch_size)
            l_indices.append(i)
        else:
            raise ValueError(f"Encountered image pair with prefer_0={prefer_0[i]} and prefer_1={prefer_1[i]}.")
    images = images[w_indices + l_indices]
    # Pair i now has its winner at index i and its loser at index i + batch_size. Pairs without exactly one preference
    # raise above rather than being filtered out, so this leaves batch_size unchanged.
    batch_size = images.shape[0] // 2

    # Convert images to latent space.
    # The VAE output may have been cached and included in the data_batch. If not, we calculate it here.
    # NOTE(review): A cached 'vae_output' is used as-is. It is not re-ordered into the winner/loser order above, and it
    # is not multiplied by the VAE scaling factor. Confirm that whatever produces the cache already does both.
    latents = data_batch.get("vae_output", None)
    if latents is None:
        latents = vae.encode(images.to(dtype=weight_dtype)).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

    # Sample noise that we'll add to the latents.
    # We want to use the same noise for the winning and losing example in each pair, so we generate noise for the
    # winning latents and then repeat it.
    noise = torch.randn_like(latents[:batch_size])
    noise = noise.repeat((2, 1, 1, 1))

    # Sample a random timestep for each image **pair**.
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=latents.device)
    timesteps = timesteps.repeat((2,)).long()

    # Add noise to the latents according to the noise magnitude at each timestep (this is the forward
    # diffusion process).
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Get the text embedding for conditioning (for both the text_encoder and ref_text_encoder).
    # The text_encoder_output may have been cached and included in the data_batch. If not, we calculate it here.
    encoder_hidden_states = data_batch.get("text_encoder_output", None)
    if encoder_hidden_states is None:
        caption_token_ids = tokenize_captions(tokenizer, data_batch["caption"]).to(text_encoder.device)
        encoder_hidden_states = text_encoder(caption_token_ids)[0].to(dtype=weight_dtype)
        with torch.no_grad():
            ref_encoder_hidden_states = ref_text_encoder(caption_token_ids)[0].to(dtype=weight_dtype)
    else:
        # Text encoder output caching requires `train_text_encoder == False`, so the text encoder is unchanged from the
        # reference and the cached embeddings serve for both. (Previously ref_encoder_hidden_states was left unbound on
        # this path.)
        ref_encoder_hidden_states = encoder_hidden_states
    # Each pair shares one caption, so repeat the embeddings for the winner half and the loser half of the batch.
    encoder_hidden_states = encoder_hidden_states.repeat((2, 1, 1))
    ref_encoder_hidden_states = ref_encoder_hidden_states.repeat((2, 1, 1))

    # Get the target for loss depending on the prediction type.
    if config.prediction_type is not None:
        # Set the prediction_type of scheduler if it's defined in config.
        noise_scheduler.register_to_config(prediction_type=config.prediction_type)
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    # Predict the noise residual. The reference prediction is only a constant baseline in the loss, so no autograd
    # graph is recorded for it.
    with torch.no_grad():
        ref_model_pred: torch.Tensor = ref_unet(noisy_latents, timesteps, ref_encoder_hidden_states).sample
    model_pred: torch.Tensor = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Compute the loss in float32. All operands are cast so that mse_loss never sees mismatched dtypes, which can fail
    # in the backward pass when the models run in reduced precision.
    target = target.float()
    model_pred = model_pred.float()
    ref_model_pred = ref_model_pred.float()

    w_target = target[:batch_size]
    l_target = target[batch_size:]

    model_w_pred = model_pred[:batch_size]
    model_l_pred = model_pred[batch_size:]
    ref_w_pred = ref_model_pred[:batch_size]
    ref_l_pred = ref_model_pred[batch_size:]

    # The pseudo-code from the paper uses `.norm().pow(2)` to calculate the errors. We take the mean over all pixels
    # rather than the sum over all pixels instead. This helps keep the learning rate stable across different image
    # resolutions. It also means that the recommended settings for beta from the paper are not correct.
    # > model_w_err = (model_w_pred - target).norm().pow(2)
    # > model_l_err = (model_l_pred - target).norm().pow(2)
    # > ref_w_err = (ref_w_pred - target).norm().pow(2)
    # > ref_l_err = (ref_l_pred - target).norm().pow(2)
    # NOTE(review): mse_loss's default reduction averages over the whole batch as well, so for batch_size > 1 the
    # logsigmoid below is applied to batch-averaged errors, not averaged over per-pair losses. Confirm this is intended.
    model_w_err = torch.nn.functional.mse_loss(model_w_pred, w_target)
    model_l_err = torch.nn.functional.mse_loss(model_l_pred, l_target)
    ref_w_err = torch.nn.functional.mse_loss(ref_w_pred, w_target)
    ref_l_err = torch.nn.functional.mse_loss(ref_l_pred, l_target)

    w_diff = model_w_err - ref_w_err
    l_diff = model_l_err - ref_l_err

    inside_term = -1 * config.beta * (w_diff - l_diff)

    loss = -1 * torch.nn.functional.logsigmoid(inside_term)
    return loss
def train(config: SdDirectPreferenceOptimizationLoraConfig, callbacks: list[PipelineCallbacks] | None = None):  # noqa: C901
    """Run Diffusion-DPO LoRA training for a Stable Diffusion model.

    Copies of the loaded text encoder and UNet are kept as "reference" models. They are never added to the optimizer's
    parameter groups and are passed alongside the trained models to `train_forward_dpo()` on every step.

    Args:
        config (SdDirectPreferenceOptimizationLoraConfig): The training configuration.
        callbacks (list[PipelineCallbacks] | None): Not supported by this pipeline. Must be None or empty.

    Raises:
        ValueError: If callbacks are provided, or if `cache_text_encoder_outputs` and `train_text_encoder` are both
            True.
        NotImplementedError: If `cache_vae_outputs` is True.
    """
    if callbacks:
        raise ValueError(f"This pipeline does not support callbacks, but {len(callbacks)} were provided.")

    # Give a clear error message if an unsupported base model was chosen.
    # TODO(ryan): Update this check to work with single-file SD checkpoints.
    # check_base_model_version(
    #     {BaseModelVersionEnum.STABLE_DIFFUSION_V1, BaseModelVersionEnum.STABLE_DIFFUSION_V2},
    #     config.model,
    #     local_files_only=False,
    # )

    # Create a timestamped directory for all outputs.
    out_dir = os.path.join(config.base_output_dir, f"{time.time()}")
    ckpt_dir = os.path.join(out_dir, "checkpoints")
    os.makedirs(ckpt_dir)

    accelerator = initialize_accelerator(
        out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to
    )
    logger = initialize_logging(os.path.basename(__file__), accelerator)

    # Set the accelerate seed.
    if config.seed is not None:
        set_seed(config.seed)

    # Log the accelerator configuration from every process to help with debugging.
    logger.info(accelerator.state, main_process_only=False)

    logger.info("Starting LoRA Training.")
    logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}")
    logger.info(f"Output dir: '{out_dir}'")

    # Write the configuration to disk.
    with open(os.path.join(out_dir, "config.json"), "w") as f:
        json.dump(config.dict(), f, indent=2, default=str)

    weight_dtype = get_dtype_from_str(config.weight_dtype)

    logger.info("Loading models.")
    tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd(
        logger=logger,
        model_name_or_path=config.model,
        hf_variant=config.hf_variant,
        base_embeddings=config.base_embeddings,
        dtype=weight_dtype,
    )
    # Reference models for the DPO loss. They are copied before any LoRA layers are added to the trained models.
    ref_text_encoder = copy.deepcopy(text_encoder)
    ref_unet = copy.deepcopy(unet)

    if config.xformers:
        import_xformers()

        # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers
        # will fail because Q, K, V have different dtypes.
        unet.enable_xformers_memory_efficient_attention()
        ref_unet.enable_xformers_memory_efficient_attention()
        vae.enable_xformers_memory_efficient_attention()

    # Prepare text encoder output cache.
    text_encoder_output_cache_dir_name = None
    if config.cache_text_encoder_outputs:
        # TODO(ryand): Think about how to better check if it is safe to cache the text encoder outputs. Currently, there
        # are a number of configurations that would cause variation in the text encoder outputs and should not be used
        # with caching.

        # TODO(ryand): This check does not make sense when config.initial_lora is set.
        if config.train_text_encoder:
            raise ValueError("'cache_text_encoder_outputs' and 'train_text_encoder' cannot both be True.")

        # We use a temporary directory for the cache. The directory will automatically be cleaned up when
        # tmp_text_encoder_output_cache_dir is destroyed.
        tmp_text_encoder_output_cache_dir = tempfile.TemporaryDirectory()
        text_encoder_output_cache_dir_name = tmp_text_encoder_output_cache_dir.name
        if accelerator.is_local_main_process:
            # Only the main process should populate the cache.
            logger.info(f"Generating text encoder output cache ('{text_encoder_output_cache_dir_name}').")
            text_encoder.to(accelerator.device, dtype=weight_dtype)
            cache_text_encoder_outputs(text_encoder_output_cache_dir_name, config, tokenizer, text_encoder)
        # Move the text_encoder back to the CPU, because it is not needed for training.
        text_encoder.to("cpu")
        accelerator.wait_for_everyone()
    else:
        text_encoder.to(accelerator.device, dtype=weight_dtype)
        ref_text_encoder.to(accelerator.device, dtype=weight_dtype)

    # Prepare VAE output cache.
    vae_output_cache_dir_name = None
    if config.cache_vae_outputs:
        raise NotImplementedError("VAE caching is not implemented for Diffusion-DPO training yet.")
        # # We use a temporary directory for the cache. The directory will automatically be cleaned up when
        # # tmp_vae_output_cache_dir is destroyed.
        # tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()
        # vae_output_cache_dir_name = tmp_vae_output_cache_dir.name
        # if accelerator.is_local_main_process:
        #     # Only the main process should populate the cache.
        #     logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').")
        #     vae.to(accelerator.device, dtype=weight_dtype)
        #     data_loader = build_data_loader(
        #         data_loader_config=config.data_loader,
        #         batch_size=config.train_batch_size,
        #         shuffle=False,
        #         sequential_batching=True,
        #     )
        #     cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)
        # # Move the VAE back to the CPU, because it is not needed for training.
        # vae.to("cpu")
        # accelerator.wait_for_everyone()
    else:
        vae.to(accelerator.device, dtype=weight_dtype)

    unet.to(accelerator.device, dtype=weight_dtype)
    ref_unet.to(accelerator.device, dtype=weight_dtype)

    # Add LoRA layers to the models being trained.
    # `trainable_param_groups` is passed to the optimizer; `all_trainable_models` is used for dtype casting and
    # gradient clipping. Both are populated by prep_peft_model() below.
    trainable_param_groups = []
    all_trainable_models: list[peft.PeftModel] = []

    if config.initial_lora is not None:
        unet, text_encoder = load_sd_peft_checkpoint(
            checkpoint_dir=config.initial_lora, unet=unet, text_encoder=text_encoder, is_trainable=True
        )
        ref_unet, ref_text_encoder = load_sd_peft_checkpoint(
            checkpoint_dir=config.initial_lora, unet=ref_unet, text_encoder=ref_text_encoder, is_trainable=False
        )
    else:
        if config.train_unet:
            unet_lora_config = peft.LoraConfig(
                r=config.lora_rank_dim,
                # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred?
                lora_alpha=1.0,
                target_modules=UNET_TARGET_MODULES,
            )
            unet = peft.get_peft_model(unet, unet_lora_config)
        if config.train_text_encoder:
            text_encoder_lora_config = peft.LoraConfig(
                r=config.lora_rank_dim,
                lora_alpha=1.0,
                # init_lora_weights="gaussian",
                target_modules=TEXT_ENCODER_TARGET_MODULES,
            )
            text_encoder = peft.get_peft_model(text_encoder, text_encoder_lora_config)

    def prep_peft_model(model, lr: float | None = None):
        """Register `model` for training if it is a PeftModel. Returns True if it was registered."""
        if not isinstance(model, peft.PeftModel):
            return False

        model.print_trainable_parameters()

        # Populate `trainable_param_groups`, to be passed to the optimizer.
        param_group = {"params": list(filter(lambda p: p.requires_grad, model.parameters()))}
        if lr is not None:
            param_group["lr"] = lr
        trainable_param_groups.append(param_group)

        # Populate all_trainable_models.
        all_trainable_models.append(model)

        model.train()

        return True

    training_unet = prep_peft_model(unet, config.unet_learning_rate)
    training_text_encoder = prep_peft_model(text_encoder, config.text_encoder_learning_rate)

    # If mixed_precision is enabled, cast all trainable params to float32.
    if config.mixed_precision != "no":
        for trainable_model in all_trainable_models:
            for param in trainable_model.parameters():
                if param.requires_grad:
                    param.data = param.to(torch.float32)

    if config.gradient_checkpointing:
        # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.
        unet.enable_gradient_checkpointing()
        # unet must be in train() mode for gradient checkpointing to take effect.
        # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does
        # not change its forward behavior.
        unet.train()
        if training_text_encoder:
            text_encoder.gradient_checkpointing_enable()

            # The text encoder must be in train() mode for gradient checkpointing to take effect. This should
            # already be the case, since we are training the text_encoder, but we do it explicitly to make it clear
            # that this is required.
            # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text
            # encoders in train mode does not change their forward behavior.
            text_encoder.train()

            # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder
            # LoRA weights would have 0 gradients, and so would not get trained. Note that the set of
            # trainable_param_groups has already been populated - the embeddings will not be trained.
            text_encoder.text_model.embeddings.requires_grad_(True)

    optimizer = initialize_optimizer(config.optimizer, trainable_param_groups)

    data_loader = build_image_pair_preference_sd_dataloader(
        config=config.data_loader,
        batch_size=config.train_batch_size,
        text_encoder_output_cache_dir=text_encoder_output_cache_dir_name,
        text_encoder_cache_field_to_output_field={"text_encoder_output": "text_encoder_output"},
        vae_output_cache_dir=vae_output_cache_dir_name,
        shuffle=True,
    )

    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps
    # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears
    # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process
    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),
    # so the scaling here simply reverses that behaviour.
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(
        config.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=config.max_train_steps * accelerator.num_processes,
    )

    prepared_result: tuple[
        UNet2DConditionModel | peft.PeftModel,
        CLIPTextModel | peft.PeftModel,
        torch.optim.Optimizer,
        torch.utils.data.DataLoader,
        torch.optim.lr_scheduler.LRScheduler,
    ] = accelerator.prepare(
        unet,
        text_encoder,
        optimizer,
        data_loader,
        lr_scheduler,
        # Disable automatic device placement for text_encoder if the text encoder outputs were cached.
        device_placement=[True, not config.cache_text_encoder_outputs, True, True, True],
    )
    unet, text_encoder, optimizer, data_loader, lr_scheduler = prepared_result

    # Calculate the number of epochs and total training steps. A "step" represents a single weight update operation
    # (i.e. takes into account gradient accumulation steps).
    # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when
    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.
    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)
    num_train_epochs = math.ceil(config.max_train_steps / num_steps_per_epoch)

    if accelerator.is_main_process:
        accelerator.init_trackers("lora_training")
        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.
        accelerator.log({"configuration": f"```json\n{json.dumps(config.dict(), indent=2, default=str)}\n```\n"})

    checkpoint_tracker = CheckpointTracker(
        base_dir=ckpt_dir,
        prefix="checkpoint",
        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
        max_checkpoints=config.max_checkpoints,
    )

    # Train!
    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(data_loader)}")
    logger.info(f" Instantaneous batch size per device = {config.train_batch_size}")
    logger.info(f" Gradient accumulation steps = {config.gradient_accumulation_steps}")
    logger.info(f" Parallel processes = {accelerator.num_processes}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Total optimization steps = {config.max_train_steps}")

    global_step = 0
    first_epoch = 0
    completed_epochs = first_epoch

    progress_bar = tqdm(
        range(global_step, config.max_train_steps),
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, num_train_epochs):
        train_loss = 0.0
        for data_batch_idx, data_batch in enumerate(data_loader):
            with accelerator.accumulate(unet, text_encoder):
                loss = train_forward_dpo(
                    config=config,
                    data_batch=data_batch,
                    vae=vae,
                    noise_scheduler=noise_scheduler,
                    tokenizer=tokenizer,
                    text_encoder=text_encoder,
                    unet=unet,
                    ref_text_encoder=ref_text_encoder,
                    ref_unet=ref_unet,
                    weight_dtype=weight_dtype,
                )

                # Gather the losses across all processes for logging (if we use distributed training).
                # TODO(ryand): Test that this works properly with distributed training.
                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()
                train_loss += avg_loss.item() / config.gradient_accumulation_steps

                # Backpropagate.
                accelerator.backward(loss)
                if accelerator.sync_gradients and config.max_grad_norm is not None:
                    params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models])
                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Checks if the accelerator has performed an optimization step behind the scenes.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1
                log = {"train_loss": train_loss}

                lrs = lr_scheduler.get_last_lr()
                if training_unet:
                    # When training the UNet, it will always be the first parameter group.
                    log["lr/unet"] = float(lrs[0])
                    if config.optimizer.optimizer_type == "Prodigy":
                        log["lr/d*lr/unet"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
                if training_text_encoder:
                    # When training the text encoder, it will always be the last parameter group.
                    log["lr/text_encoder"] = float(lrs[-1])
                    if config.optimizer.optimizer_type == "Prodigy":
                        log["lr/d*lr/text_encoder"] = optimizer.param_groups[-1]["d"] * optimizer.param_groups[-1]["lr"]

                accelerator.log(log, step=global_step)
                train_loss = 0.0

                # global_step represents the *number of completed steps* at this point.
                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        _save_sd_lora_checkpoint(
                            epoch=completed_epochs,
                            step=global_step,
                            unet=accelerator.unwrap_model(unet) if training_unet else None,
                            text_encoder=accelerator.unwrap_model(text_encoder) if training_text_encoder else None,
                            logger=logger,
                            checkpoint_tracker=checkpoint_tracker,
                            lora_checkpoint_format=config.lora_checkpoint_format,
                        )

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            if global_step >= config.max_train_steps:
                break

        # Save a checkpoint every n epochs.
        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:
            # wait_for_everyone() is a barrier that every process must reach, so it has to be called outside of the
            # is_main_process check (same as the step-checkpoint logic above). Calling it only on the main process
            # would leave that process blocked in distributed training.
            accelerator.wait_for_everyone()
            if accelerator.is_main_process:
                _save_sd_lora_checkpoint(
                    epoch=completed_epochs,
                    step=global_step,
                    unet=accelerator.unwrap_model(unet) if training_unet else None,
                    text_encoder=accelerator.unwrap_model(text_encoder) if training_text_encoder else None,
                    logger=logger,
                    checkpoint_tracker=checkpoint_tracker,
                    lora_checkpoint_format=config.lora_checkpoint_format,
                )

        # Generate validation images every n epochs.
        if len(config.validation_prompts) > 0 and completed_epochs % config.validate_every_n_epochs == 0:
            if accelerator.is_main_process:
                generate_validation_images_sd(
                    epoch=completed_epochs,
                    step=global_step,
                    out_dir=out_dir,
                    accelerator=accelerator,
                    vae=vae,
                    text_encoder=text_encoder,
                    tokenizer=tokenizer,
                    noise_scheduler=noise_scheduler,
                    unet=unet,
                    config=config,
                    logger=logger,
                )

    accelerator.end_training()
================================================
FILE: src/invoke_training/pipelines/callbacks.py
================================================
from abc import ABC
from enum import Enum
class ModelType(Enum):
    """The kind of model file that invoke-training can produce.

    There is one member per producible file type, and every member's value equals its name. Splitting this into
    several enums (e.g. base model / model type / checkpoint format) was considered, but no taxonomy has yet felt
    sufficiently future-proof, so a single flat enum is used for now.
    """

    FLUX_LORA_PEFT = "FLUX_LORA_PEFT"  # Flux LoRA, PEFT format.
    FLUX_LORA_KOHYA = "FLUX_LORA_KOHYA"  # Flux LoRA, Kohya format.
    SD1_LORA_KOHYA = "SD1_LORA_KOHYA"  # Stable Diffusion 1.x LoRA, Kohya format.
    SD1_LORA_PEFT = "SD1_LORA_PEFT"  # Stable Diffusion 1.x LoRA, PEFT format.
    SDXL_LORA_KOHYA = "SDXL_LORA_KOHYA"  # Stable Diffusion XL LoRA, Kohya format.
    SDXL_LORA_PEFT = "SDXL_LORA_PEFT"  # Stable Diffusion XL LoRA, PEFT format.
    SD1_TEXTUAL_INVERSION = "SD1_TEXTUAL_INVERSION"  # Stable Diffusion 1.x Textual Inversion.
    SDXL_TEXTUAL_INVERSION = "SDXL_TEXTUAL_INVERSION"  # Stable Diffusion XL Textual Inversion.
    SD1_UNET_DIFFUSERS = "SD1_UNET_DIFFUSERS"  # Stable Diffusion 1.x UNet, diffusers format.
    SDXL_UNET_DIFFUSERS = "SDXL_UNET_DIFFUSERS"  # Stable Diffusion XL UNet, diffusers format.
    SDXL_FULL_DIFFUSERS = "SDXL_FULL_DIFFUSERS"  # Full Stable Diffusion XL pipeline, diffusers format.
class ModelCheckpoint:
    """One model inside a checkpoint: where it was saved and what kind of model it is.

    Args:
        file_path (str): Path of the saved model.
        model_type (ModelType): The kind of model stored at `file_path`.
    """

    def __init__(self, file_path: str, model_type: ModelType):
        self.file_path, self.model_type = file_path, model_type
class TrainingCheckpoint:
    """A checkpoint taken at one point in training.

    When several models are trained at once, a single training checkpoint groups all of their model checkpoints.

    Args:
        models (list[ModelCheckpoint]): The model checkpoints that make up this training checkpoint.
        epoch (int): The epoch value recorded with the checkpoint.
        step (int): The step value recorded with the checkpoint.
    """

    def __init__(self, models: list[ModelCheckpoint], epoch: int, step: int):
        self.models, self.epoch, self.step = models, epoch, step
class ValidationImage:
    """A single image generated from one validation prompt.

    Args:
        file_path (str): Path to the image file.
        prompt (str): The prompt the image was generated from.
        image_idx (int): Position of this image within the current validation set, i.e. among the images generated
            from the same prompt at the same validation point.
    """

    def __init__(self, file_path: str, prompt: str, image_idx: int):
        self.file_path, self.prompt, self.image_idx = file_path, prompt, image_idx
class ValidationImages:
    """All validation images produced at a single validation point.

    Args:
        images (list[ValidationImage]): The validation images.
        epoch (int): The last completed epoch when these images were generated.
        step (int): The last completed training step when these images were generated.
    """

    def __init__(self, images: list[ValidationImage], epoch: int, step: int):
        self.images, self.epoch, self.step = images, epoch, step
class PipelineCallbacks(ABC):
    """Hooks that a training pipeline can invoke while it runs.

    Every hook is a no-op by default, so subclasses only need to override the ones they care about.
    """

    def on_save_checkpoint(self, checkpoint: TrainingCheckpoint):
        """Hook for a newly saved training checkpoint. Does nothing by default."""

    def on_save_validation_images(self, images: ValidationImages):
        """Hook for a newly saved set of validation images. Does nothing by default."""
================================================
FILE: src/invoke_training/pipelines/flux/lora/__init__.py
================================================
================================================
FILE: src/invoke_training/pipelines/flux/lora/config.py
================================================
from typing import Annotated, Literal, Union
from pydantic import Field
from invoke_training._shared.flux.lora_checkpoint_utils import (
FLUX_TRANSFORMER_TARGET_MODULES,
TEXT_ENCODER_TARGET_MODULES,
)
from invoke_training.config.base_pipeline_config import BasePipelineConfig
from invoke_training.config.data.data_loader_config import ImageCaptionFluxDataLoaderConfig
from invoke_training.config.optimizer.optimizer_config import (
AdamOptimizerConfig,
ProdigyOptimizerConfig,
)
class FluxLoraConfig(BasePipelineConfig):
    # Configuration for the Flux LoRA training pipeline.
    # NOTE: `train_transformer` was previously declared twice in this class. The duplicate declaration has been removed;
    # the field keeps its original position and default value (True).
    type: Literal["FLUX_LORA"] = "FLUX_LORA"

    model: str = "black-forest-labs/FLUX.1-dev"
    """Name or path of the base model to train. Can be in diffusers format, or a single Flux.1-dev checkpoint
    file. (E.g. 'black-forest-labs/FLUX.1-dev', '/path/to/flux.1-dev.safetensors', etc. )
    """

    transformer_path: str | None = None
    """Path to the custom transformer .safetensors file. If not provided, the default black-forest-labs/FLUX.1-dev
    transformer will be used.
    """

    text_encoder_1_path: str | None = None
    """Path to the custom CLIP text encoder .safetensors file. If not provided, the default openai/clip-vit-base-patch32
    text encoder will be used.
    """

    text_encoder_2_path: str | None = None
    """Path to the custom T5 text encoder .safetensors file. If not provided, the default google/t5-v1_1-xl text encoder
    will be used.
    """

    lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"
    """The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`."""

    train_transformer: bool = True
    """Whether to add LoRA layers to the Flux transformer (FluxTransformer2DModel) and train it.
    """

    train_text_encoder: bool = False
    """Whether to add LoRA layers to the text encoder and train it.
    """

    optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()

    text_encoder_learning_rate: float | None = 1e-4
    """The learning rate to use for the text encoder model. If set, this overrides the optimizer's default learning
    rate. Set to null or 0 to use the optimizer's default learning rate.
    """

    transformer_learning_rate: float | None = 4e-4
    """The learning rate to use for the transformer model. If set, this overrides the optimizer's default learning
    rate. Set to null or 0 to use the optimizer's default learning rate.
    """

    lr_scheduler: Literal[
        "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"
    ] = "constant_with_warmup"

    lr_warmup_steps: int = 10
    """The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.
    See lr_scheduler.
    """

    min_snr_gamma: float | None = None
    """Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy
    improves the speed of training convergence by adjusting the weight of each sample.

    `min_snr_gamma` acts like an upper bound on the weight of samples with low noise levels.

    If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`.
    """

    lora_rank_dim: int = 4
    """The rank dimension to use for the LoRA layers. Increasing the rank dimension increases the model's expressivity,
    but also increases the size of the generated LoRA model.
    """

    flux_lora_target_modules: list[str] = FLUX_TRANSFORMER_TARGET_MODULES
    """The list of target modules to apply LoRA layers to in the FluxTransformer2DModel. The default list will produce a
    highly expressive LoRA model.

    For a smaller and less expressive LoRA model, the following list is recommended:
    ```python
    flux_lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
    ```

    The list of target modules is passed to Hugging Face's PEFT library. See
    [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for
    details.
    """

    text_encoder_lora_target_modules: list[str] = TEXT_ENCODER_TARGET_MODULES
    """The list of target modules to apply LoRA layers to in the CLIP text encoder. The default list will produce a
    highly expressive LoRA model.

    For a smaller and less expressive LoRA model, the following list is recommended:
    ```python
    text_encoder_lora_target_modules = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"]
    ```

    The list of target modules is passed to Hugging Face's PEFT library. See
    [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for
    details.
    """

    cache_text_encoder_outputs: bool = False
    """If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and
    the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the
    text encoders in VRAM), and speeds up training (don't have to run the text encoders for each training example).
    This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being
    applied.
    """

    cache_vae_outputs: bool = False
    """If True, the VAE will be applied to all of the images in the dataset before starting training and the results
    will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and
    speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all
    non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False).
    """

    enable_cpu_offload_during_validation: bool = False
    """If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation
    images. This reduces VRAM requirements at the cost of slower generation of validation images.
    """

    gradient_accumulation_steps: int = 1
    """The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face
    Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.
    """

    weight_dtype: Literal["float32", "float16", "bfloat16"] = "float16"
    """All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and
    result in faster training, but are more prone to issues with numerical stability.

    Recommendations:

    - `"float32"`: Use this mode if you have plenty of VRAM available.
    - `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
    - `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.

    See also [`mixed_precision`][invoke_training.pipelines.flux.lora.config.FluxLoraConfig.mixed_precision].
    """  # noqa: E501

    mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
    """The mixed precision mode to use.

    If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and
    trainable parameters are kept in float32 precision to avoid issues with numerical stability.

    This value is passed to Hugging Face Accelerate. See
    [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)
    for more details.
    """  # noqa: E501

    gradient_checkpointing: bool = False
    """Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling
    gradient checkpointing slows down training by ~20%.
    """

    max_checkpoints: int | None = None
    """The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this
    limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.
    """

    prediction_type: Literal["epsilon", "v_prediction"] | None = None
    """The prediction_type that will be used for training. Choose between 'epsilon' or 'v_prediction' or leave 'None'.
    If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used.
    """

    max_grad_norm: float | None = None
    """Max gradient norm for clipping. Set to null or 0 for no clipping.
    """

    validation_prompts: list[str] = []
    """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.
    See also 'validate_every_n_epochs'.
    """

    num_validation_images_per_prompt: int = 4
    """The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can
    become quite slow if this number is too large.
    """

    train_batch_size: int = 1
    """The training batch size.
    """

    use_masks: bool = False
    """If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this
    feature to be used.
    """

    data_loader: Annotated[Union[ImageCaptionFluxDataLoaderConfig], Field(discriminator="type")]
    """The data loader configuration.
    """

    timestep_sampler: Literal["shift", "uniform"] = "shift"
    """The timestep sampler to use. Choose between 'shift' or 'uniform'."""

    discrete_flow_shift: float = 3.0
    """The shift parameter for the discrete flow. Only used if `timestep_sampler == "shift"`.
    """

    sigmoid_scale: float = 1.0
    """The scale parameter for the sigmoid function. Only used if `timestep_sampler == "shift"`.
    """

    lora_scale: float | None = 1.0
    """The scale parameter for the LoRA layers.
    """

    guidance_scale: float = 1.0
    """The guidance scale for the Flux model.
    """

    clip_tokenizer_max_length: int = 77
    """The maximum length of the CLIP tokenizer. The maximum length of the CLIP tokenizer is 77.
    """

    t5_tokenizer_max_length: int = 512
    """The maximum length of the T5 tokenizer. The maximum length of the T5 tokenizer is 512.
    """
================================================
FILE: src/invoke_training/pipelines/flux/lora/train.py
================================================
import itertools
import json
import logging
import math
import os
import tempfile
import time
from pathlib import Path
from typing import Literal, Optional, Union
import numpy as np
import peft
import torch
import torch.utils.data
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel
from diffusers.optimization import get_scheduler
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from peft import PeftModel
from PIL import Image
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invoke_training._shared.accelerator.accelerator_utils import (
get_dtype_from_str,
initialize_accelerator,
initialize_logging,
)
from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker
from invoke_training._shared.data.data_loaders.image_caption_flux_dataloader import build_image_caption_flux_dataloader
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training._shared.flux.encoding_utils import encode_prompt
from invoke_training._shared.flux.lora_checkpoint_utils import (
save_flux_kohya_checkpoint,
save_flux_peft_checkpoint,
)
from invoke_training._shared.flux.model_loading_utils import load_models_flux
from invoke_training._shared.flux.validation import generate_validation_images_flux
from invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer
from invoke_training._shared.stable_diffusion.tokenize_captions import tokenize_captions
from invoke_training.config.data.data_loader_config import ImageCaptionSDDataLoaderConfig
from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint
from invoke_training.pipelines.flux.lora.config import FluxLoraConfig
def _save_flux_lora_checkpoint(
    epoch: int,
    step: int,
    transformer: peft.PeftModel | None,
    text_encoder_1: CLIPTextModel | None,
    text_encoder_2: T5EncoderModel | None,
    logger: logging.Logger,
    checkpoint_tracker: CheckpointTracker,
    callbacks: list[PipelineCallbacks] | None,
    lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "invoke_peft",
):
    """Write the Flux LoRA weights as a new checkpoint and notify any callbacks.

    Args:
        epoch (int): Epoch value used to name the checkpoint and reported to callbacks.
        step (int): Step value used to name the checkpoint and reported to callbacks.
        transformer (peft.PeftModel | None): The transformer LoRA model to save, if any.
        text_encoder_1 (CLIPTextModel | None): The CLIP text encoder LoRA model to save, if any.
        text_encoder_2 (T5EncoderModel | None): The T5 text encoder LoRA model to save, if any.
        logger (logging.Logger): Logger used to report pruned checkpoints.
        checkpoint_tracker (CheckpointTracker): Provides pruning and the output path.
        callbacks (list[PipelineCallbacks] | None): Each one receives `on_save_checkpoint()` after the save.
        lora_checkpoint_format (Literal["invoke_peft", "kohya"]): The on-disk format to write.

    Raises:
        ValueError: If `lora_checkpoint_format` is unsupported. Note that old checkpoints have already been pruned by
            the time this is raised.
    """
    # Prune old checkpoints first (prune(1) presumably leaves room for the one about to be written).
    pruned_count = checkpoint_tracker.prune(1)
    if pruned_count > 0:
        logger.info(f"Pruned {pruned_count} checkpoint(s).")

    save_path = checkpoint_tracker.get_path(epoch=epoch, step=step)

    # Pick the writer and the matching ModelType for the requested format.
    if lora_checkpoint_format == "invoke_peft":
        model_type, save_fn = ModelType.FLUX_LORA_PEFT, save_flux_peft_checkpoint
    elif lora_checkpoint_format == "kohya":
        model_type, save_fn = ModelType.FLUX_LORA_KOHYA, save_flux_kohya_checkpoint
    else:
        raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")

    save_fn(Path(save_path), transformer=transformer, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2)

    # Each callback receives its own TrainingCheckpoint instance.
    for callback in callbacks or []:
        callback.on_save_checkpoint(
            TrainingCheckpoint(
                models=[ModelCheckpoint(file_path=save_path, model_type=model_type)], epoch=epoch, step=step
            )
        )
def _build_data_loader(
    data_loader_config: Union[ImageCaptionSDDataLoaderConfig],
    batch_size: int,
    use_masks: bool = False,
    text_encoder_output_cache_dir: Optional[str] = None,
    vae_output_cache_dir: Optional[str] = None,
    shuffle: bool = True,
    sequential_batching: bool = False,
) -> DataLoader:
    """Build the DataLoader described by `data_loader_config`.

    Only configs whose `type` is "IMAGE_CAPTION_FLUX_DATA_LOADER" are supported. Cached text encoder outputs are read
    from the "text_encoder_output" cache field into the "text_encoder_output" batch field.

    NOTE(review): `sequential_batching` is accepted but not forwarded anywhere. Also, the `data_loader_config`
    annotation names the SD config class while the only accepted type is the Flux one. Confirm and update both.

    Raises:
        ValueError: If `data_loader_config.type` is not supported.
    """
    config_type = data_loader_config.type
    if config_type != "IMAGE_CAPTION_FLUX_DATA_LOADER":
        raise ValueError(f"Unsupported data loader config type: '{config_type}'.")

    return build_image_caption_flux_dataloader(
        config=data_loader_config,
        batch_size=batch_size,
        use_masks=use_masks,
        text_encoder_output_cache_dir=text_encoder_output_cache_dir,
        text_encoder_cache_field_to_output_field={"text_encoder_output": "text_encoder_output"},
        vae_output_cache_dir=vae_output_cache_dir,
        shuffle=shuffle,
    )
def cache_text_encoder_outputs(
    cache_dir: str, config: FluxLoraConfig, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel
):
    """Run the text encoder on all captions in the dataset and cache the results to disk.

    For each example, a dict with the single key "text_encoder_output" is saved under the example's "id". The value
    is that example's slice of element [0] of the text encoder output (for a `CLIPTextModel`, the last hidden state).

    NOTE(review): this handles a single CLIP tokenizer/text encoder pair only, and writes "text_encoder_output"
    rather than separate prompt/pooled embeddings. Confirm this matches what the FLUX training step expects before
    relying on the cache.

    Args:
        cache_dir (str): The directory where the results will be cached.
        config (FluxLoraConfig): Training config.
        tokenizer (CLIPTokenizer): The tokenizer.
        text_encoder (CLIPTextModel): The text_encoder.
    """
    data_loader = _build_data_loader(
        data_loader_config=config.data_loader,
        batch_size=config.train_batch_size,
        shuffle=False,
        sequential_batching=True,
    )

    cache = TensorDiskCache(cache_dir)

    # This is an inference-only pass. Disable autograd so that no computation graph is built for the text encoder,
    # and so that the cached tensors carry no autograd history.
    with torch.no_grad():
        for data_batch in tqdm(data_loader):
            caption_token_ids = tokenize_captions(tokenizer, data_batch["caption"]).to(text_encoder.device)
            text_encoder_output_batch = text_encoder(caption_token_ids)[0]
            # Split batch before caching.
            for i in range(len(data_batch["id"])):
                cache.save(data_batch["id"][i], {"text_encoder_output": text_encoder_output_batch[i]})
def cache_vae_outputs(cache_dir: str, data_loader: DataLoader, vae: AutoencoderKL):
    """Run the VAE on all images in the dataset and cache the results to disk.

    For each example, a dict is saved under the example's "id" containing:
    - "vae_output": the sampled latents multiplied by `vae.config.scaling_factor`.
    - "original_size_hw" and "crop_top_left_yx": copied from the data batch.
    - "mask": copied from the data batch, only if the batch contains masks.

    NOTE(review): the latents are scaled but not shifted. diffusers' `FluxPipeline` decodes with
    `latents / scaling_factor + shift_factor`, so FLUX latents are normally normalized as
    `(latents - shift_factor) * scaling_factor`. Confirm the expected normalization before wiring this cache into
    training.

    Args:
        cache_dir (str): The directory where the results will be cached.
        data_loader (DataLoader): Yields batches with "id", "image", "original_size_hw", "crop_top_left_yx" and
            (optionally) "mask" entries.
        vae (AutoencoderKL): The VAE used to encode the images.
    """
    cache = TensorDiskCache(cache_dir)

    # This is an inference-only pass. Disable autograd so that no computation graph (and no encoder activations) is
    # kept for the VAE, and so that the cached tensors carry no autograd history.
    with torch.no_grad():
        for data_batch in tqdm(data_loader):
            latents = vae.encode(data_batch["image"].to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
            # Split batch before caching.
            for i in range(len(data_batch["id"])):
                data = {
                    "vae_output": latents[i],
                    "original_size_hw": data_batch["original_size_hw"][i],
                    "crop_top_left_yx": data_batch["crop_top_left_yx"][i],
                }
                if "mask" in data_batch:
                    data["mask"] = data_batch["mask"][i]
                cache.save(data_batch["id"][i], data)
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
    """Look up the scheduler sigma for each entry of `timesteps`.

    Every value in `timesteps` must occur exactly once in `noise_scheduler.timesteps`; otherwise `.item()` raises.

    Args:
        noise_scheduler: Scheduler exposing `sigmas` and `timesteps` tensors of equal length.
        timesteps (torch.Tensor): The timesteps to look up.
        device: Device for the returned tensor.
        n_dim (int): The result is padded with trailing singleton dims until it has at least this many dims.
        dtype (torch.dtype): dtype of the returned sigmas.

    Returns:
        torch.Tensor: Sigmas of shape (len(timesteps), 1, ..., 1).
    """
    schedule = noise_scheduler.timesteps.to(device)
    # Position of each requested timestep within the scheduler's timestep table.
    lookup = [(schedule == step).nonzero().item() for step in timesteps.to(device)]
    picked = noise_scheduler.sigmas.to(device=device, dtype=dtype)[lookup].flatten()

    # Append trailing singleton dims so the result broadcasts against (batch, ...) tensors.
    padding = (1,) * max(n_dim - picked.dim(), 0)
    return picked.reshape(*picked.shape, *padding)
def get_noisy_latents(noise_scheduler: FlowMatchEulerDiscreteScheduler, latents: torch.Tensor, config: FluxLoraConfig):
    """
    Generate random noise. Sample a random timestep from the distribution chosen by the config.
    Linearly interpolate between the latents and the noise based on timestep.

    See Section 3.1 of https://arxiv.org/pdf/2403.03206v1 for timestep sampling.

    `latents` may have any rank with the batch on dim 0, e.g. unpacked (B, C, H, W) latents or packed FLUX
    (B, seq_len, features) latents. The per-sample sigmas are reshaped to (B, 1, ..., 1) with the same rank as
    `latents`. Each sample is therefore mixed with its own sigma only; a fixed 4-D view would broadcast packed 3-D
    latents to (B, B, ...).

    Args:
        noise_scheduler (FlowMatchEulerDiscreteScheduler): The noise scheduler.
        latents (torch.Tensor): The clean latents.
        config (FluxLoraConfig): The config. Reads `timestep_sampler`, and for the "shift" sampler also
            `discrete_flow_shift` and `sigmoid_scale`.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: `(noisy_latents, noise, timesteps, sigmas)`,
        all cast to `latents.dtype`. `noisy_latents` and `noise` have the shape of `latents`, `timesteps` has shape
        (B,), and `sigmas` has shape (B, 1, ..., 1).
    """
    batch_size = latents.shape[0]
    dtype = latents.dtype
    device = latents.device

    noise = torch.randn_like(latents)

    if config.timestep_sampler == "shift":
        shift = config.discrete_flow_shift
        sigmas = torch.randn(batch_size, device=device)
        sigmas = sigmas * config.sigmoid_scale  # larger scale for more uniform sampling
        sigmas = sigmas.sigmoid()
        sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
        timesteps = sigmas * noise_scheduler.config.num_train_timesteps
    else:
        u = torch.rand(size=(batch_size,), device="cpu")
        indices = (u * noise_scheduler.config.num_train_timesteps).long()
        timesteps = noise_scheduler.timesteps[indices].to(device=device)
        sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)

    # One sigma per sample, with trailing singleton dims matching the rank of `latents`.
    sigmas = sigmas.reshape(batch_size, *([1] * (latents.ndim - 1)))

    # Linearly interpolate between the latents and the noise.
    noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise

    return noisy_model_input.to(dtype), noise.to(dtype), timesteps.to(dtype), sigmas.to(dtype)


def decode_latents(vae: AutoencoderKL, latents: torch.Tensor, save_path: str | None = "image.png"):
    """Decode VAE latents into a PIL image (debugging helper). Only the first batch element is returned.

    Args:
        vae (AutoencoderKL): The VAE used to decode.
        latents (torch.Tensor): Normalized, *unpacked* latents, assumed to be (batch, channels, height, width) as
            produced by `vae.encode(...)`. Packed FLUX latents must be unpacked first.
        save_path (str | None): If not None, the image is also saved to this path. Defaults to "image.png" (relative
            to the current working directory), as before.

    Returns:
        PIL.Image.Image: The decoded image.
    """
    # Undo the latent normalization `(latents - shift_factor) * scaling_factor`. VAEs without a shift_factor use 0.
    shift_factor = getattr(vae.config, "shift_factor", None) or 0.0
    latents = latents / vae.config.scaling_factor + shift_factor
    with torch.no_grad():
        image = vae.decode(latents.to(device=vae.device, dtype=vae.dtype)).sample

    # The VAE decodes to approximately [-1, 1]; map to [0, 1] as diffusers' image post-processing does.
    image = (image / 2 + 0.5).clamp(0, 1)
    # First batch element, CHW -> HWC. Cast to float32 first because numpy has no bfloat16 dtype.
    image = image[0].permute(1, 2, 0).float().cpu().numpy()
    image = Image.fromarray((image * 255).round().astype(np.uint8))
    if save_path is not None:
        image.save(save_path)
    return image


def train_forward(  # noqa: C901
    config: FluxLoraConfig,
    data_batch: dict,
    vae: AutoencoderKL,
    noise_scheduler: FlowMatchEulerDiscreteScheduler,
    tokenizer_1: CLIPTokenizer,
    tokenizer_2: T5Tokenizer,
    text_encoder_1: CLIPTextModel,
    text_encoder_2: T5EncoderModel,
    transformer: FluxTransformer2DModel | PeftModel,
    weight_dtype: torch.dtype,
    use_masks: bool = False,
    min_snr_gamma: float | None = None,
    logger: logging.Logger = None,
) -> torch.Tensor:
    """Run the forward training pass for a single data_batch.

    The pass has four stages:
    1. Encode (or read cached) latents and pack them.
    2. Noise them with flow matching.
    3. Encode (or read cached) the prompt.
    4. Predict the velocity `noise - latents` and return the MSE loss.

    Note: `weight_dtype`, `use_masks` and `min_snr_gamma` are accepted but currently unused.

    Returns:
        torch.Tensor: Loss
    """
    # Convert images to latent space.
    # The VAE output may have been cached and included in the data_batch. If not, we calculate it here.
    latents = data_batch.get("vae_output", None)
    if latents is None:
        # Cast input image to same dtype as VAE
        image = data_batch["image"].to(device=vae.device, dtype=vae.dtype)
        latents = vae.encode(image).latent_dist.sample()
        # Normalize the latents as diffusers' FluxPipeline expects. It decodes with
        # `latents / scaling_factor + shift_factor`. VAEs without a shift_factor fall back to 0.
        shift_factor = getattr(vae.config, "shift_factor", None) or 0.0
        latents = (latents - shift_factor) * vae.config.scaling_factor

    # Both freshly encoded and cached latents are unpacked (batch, channels, height, width) tensors. Pack them into
    # the token sequence layout consumed by the FLUX transformer.
    batch_size, num_channels, height, width = latents.shape
    latents = FluxPipeline._pack_latents(latents, batch_size, num_channels, height, width)

    # Position ids for the packed latent tokens (one token per 2x2 latent patch).
    latent_image_ids = FluxPipeline._prepare_latent_image_ids(
        batch_size, height // 2, width // 2, latents.device, latents.dtype
    )

    # Add noise to the latents according to the noise magnitude at each timestep (this is the forward
    # diffusion process). `noisy_latents` has the same shape as the packed `latents`.
    noisy_latents, noise, timesteps, sigmas = get_noisy_latents(noise_scheduler, latents, config)

    # Get the text embedding for conditioning.
    # The text encoder output may have been cached and included in the data_batch. If not, we calculate it here.
    if "prompt_embeds" in data_batch:
        prompt_embeds = data_batch["prompt_embeds"]
        pooled_prompt_embeds = data_batch["pooled_prompt_embeds"]
        text_ids = data_batch.get("text_ids", None)
        if text_ids is None:
            # FLUX text token position ids are all zeros, with 3 entries per token, as built by diffusers'
            # FluxPipeline.encode_prompt. Assumes cached prompt_embeds are (batch, seq_len, dim).
            text_ids = torch.zeros(prompt_embeds.shape[1], 3, device=latents.device, dtype=latents.dtype)
    else:
        prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
            prompt=data_batch["caption"],
            prompt_2=data_batch.get("caption_2", None),
            clip_tokenizer=tokenizer_1,
            t5_tokenizer=tokenizer_2,
            clip_text_encoder=text_encoder_1,
            t5_text_encoder=text_encoder_2,
            device=latents.device,
            num_images_per_prompt=1,
            lora_scale=config.lora_scale,
            clip_tokenizer_max_length=config.clip_tokenizer_max_length,
            t5_tokenizer_max_length=config.t5_tokenizer_max_length,
            logger=logger,
        )

    guidance = torch.full((batch_size,), float(config.guidance_scale), device=latents.device)

    model_pred = transformer(
        hidden_states=noisy_latents,
        # Timesteps are scaled by 1/1000, matching diffusers' FluxPipeline.
        timestep=timesteps / 1000,
        pooled_projections=pooled_prompt_embeds,
        encoder_hidden_states=prompt_embeds,
        guidance=guidance,
        txt_ids=text_ids,
        img_ids=latent_image_ids,
        return_dict=False,
    )[0]

    ### Flow matching loss
    # noisy = (1 - sigma) * latents + sigma * noise, so the velocity d(noisy)/d(sigma) is (noise - latents).
    # See here for more discussion: https://discuss.huggingface.co/t/meaning-of-vector-fields-in-flux-and-sd3-loss-function/106601
    target = noise - latents

    loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none")
    loss = loss.mean(dim=list(range(1, len(loss.shape))))
    return loss.mean()
def train(config: FluxLoraConfig, callbacks: list[PipelineCallbacks] | None = None):  # noqa: C901
    """Train FLUX LoRA layers according to `config`.

    All outputs are written to a new timestamped sub-directory of `config.base_output_dir`. These include the resolved
    `config.json`, the checkpoints, and the tracker logs.

    Args:
        config (FluxLoraConfig): Training config.
        callbacks (list[PipelineCallbacks] | None): Optional callbacks, forwarded to the checkpoint-saving and
            validation helpers.

    Raises:
        ValueError: If `config.cache_text_encoder_outputs` or `config.cache_vae_outputs` is enabled. Output caching
            is not yet wired into this pipeline.
    """
    # Fail fast, before creating output directories or loading models, on caching options that this pipeline cannot
    # honor yet:
    # - `cache_text_encoder_outputs(...)` stores a single "text_encoder_output" field from one CLIP text encoder.
    #   `train_forward(...)` instead reads "prompt_embeds" / "pooled_prompt_embeds", which are produced from both the
    #   CLIP and T5 text encoders.
    # - Neither cache directory is passed to the training data loader, so a cache would be built and then ignored.
    if config.cache_text_encoder_outputs:
        raise ValueError("'cache_text_encoder_outputs' is not yet supported for FLUX LoRA training.")
    if config.cache_vae_outputs:
        raise ValueError("'cache_vae_outputs' is not yet supported for FLUX LoRA training.")

    # Create a timestamped directory for all outputs.
    out_dir = os.path.join(config.base_output_dir, f"{time.time()}")
    ckpt_dir = os.path.join(out_dir, "checkpoints")
    os.makedirs(ckpt_dir)

    accelerator = initialize_accelerator(
        out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to
    )
    logger = initialize_logging(os.path.basename(__file__), accelerator)

    # Set the accelerate seed.
    if config.seed is not None:
        set_seed(config.seed)

    # Log the accelerator configuration from every process to help with debugging.
    logger.info(accelerator.state, main_process_only=False)

    logger.info("Starting LoRA Training.")
    logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}")
    logger.info(f"Output dir: '{out_dir}'")

    # Write the configuration to disk.
    with open(os.path.join(out_dir, "config.json"), "w") as f:
        json.dump(config.dict(), f, indent=2, default=str)

    weight_dtype = get_dtype_from_str(config.weight_dtype)

    logger.info("Loading models.")
    tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, transformer = load_models_flux(
        model_name_or_path=config.model,
        transformer_path=config.transformer_path,
        text_encoder_1_path=config.text_encoder_1_path,
        text_encoder_2_path=config.text_encoder_2_path,
        dtype=weight_dtype,
        logger=logger,
    )

    # Without output caching, the text encoders and the VAE run on every training step, so all models live on the
    # training device.
    text_encoder_1.to(accelerator.device, dtype=weight_dtype)
    text_encoder_2.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    transformer.to(accelerator.device, dtype=weight_dtype)

    # Add LoRA layers to the models being trained.
    trainable_param_groups = []
    all_trainable_models: list[peft.PeftModel] = []

    def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = None) -> peft.PeftModel:
        peft_model = peft.get_peft_model(model, lora_config)
        peft_model.print_trainable_parameters()

        # Populate `trainable_param_groups`, to be passed to the optimizer.
        param_group = {"params": list(filter(lambda p: p.requires_grad, peft_model.parameters()))}
        if lr is not None:
            param_group["lr"] = lr
        trainable_param_groups.append(param_group)

        # Populate all_trainable_models.
        all_trainable_models.append(peft_model)

        peft_model.train()

        return peft_model

    # Add LoRA layers to the model. The transformer (if trained) is injected first, so it owns parameter group 0.
    if config.train_transformer:
        transformer_lora_config = peft.LoraConfig(
            r=config.lora_rank_dim,
            # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred?
            lora_alpha=1.0,
            target_modules=config.flux_lora_target_modules,
        )
        transformer = inject_lora_layers(transformer, transformer_lora_config, lr=config.transformer_learning_rate)

    # Only text_encoder_1 (CLIP) receives LoRA layers; text_encoder_2 (T5) stays frozen.
    if config.train_text_encoder:
        text_encoder_lora_config = peft.LoraConfig(
            r=config.lora_rank_dim,
            lora_alpha=1.0,
            # init_lora_weights="gaussian",
            target_modules=config.text_encoder_lora_target_modules,
        )
        text_encoder_1 = inject_lora_layers(
            text_encoder_1, text_encoder_lora_config, lr=config.text_encoder_learning_rate
        )

    # Enable gradient checkpointing.
    if config.gradient_checkpointing:
        # We want to enable gradient checkpointing in the transformer regardless of whether it is being trained.
        transformer.enable_gradient_checkpointing()
        # The transformer must be in train() mode for gradient checkpointing to take effect.
        # NOTE(review): this assumes the transformer has no active dropout, so that train() mode does not change its
        # forward behavior. Confirm.
        transformer.train()
        if config.train_text_encoder:
            text_encoder_1.gradient_checkpointing_enable()

            # The text encoders must be in train() mode for gradient checkpointing to take effect. This should
            # already be the case, since we are training the text_encoders, but we do it explicitly to make it clear
            # that this is required.
            # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text
            # encoders in train mode does not change their forward behavior.
            text_encoder_1.train()

            # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder
            # LoRA weights would have 0 gradients, and so would not get trained. Note that the set of
            # trainable_param_groups has already been populated - the embeddings will not be trained.
            text_encoder_1.text_model.embeddings.requires_grad_(True)

    optimizer = initialize_optimizer(config.optimizer, trainable_param_groups)

    data_loader = _build_data_loader(
        data_loader_config=config.data_loader,
        batch_size=config.train_batch_size,
    )

    assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1
    assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1
    assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1

    # A "step" represents a single weight update operation (i.e. takes into account gradient accumulation steps).
    # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when
    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.
    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)
    num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch
    num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch)

    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps
    # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears
    # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process
    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),
    # so the scaling here simply reverses that behaviour.
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(
        config.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=num_train_steps * accelerator.num_processes,
    )

    prepared_result: tuple[
        FluxTransformer2DModel,
        CLIPTextModel,
        T5EncoderModel,
        torch.optim.Optimizer,
        torch.utils.data.DataLoader,
        torch.optim.lr_scheduler.LRScheduler,
    ] = accelerator.prepare(
        transformer,
        text_encoder_1,
        text_encoder_2,
        optimizer,
        data_loader,
        lr_scheduler,
        # Disable automatic device placement for text_encoder if the text encoder outputs were cached.
        device_placement=[
            True,
            not config.cache_text_encoder_outputs,
            not config.cache_text_encoder_outputs,
            True,
            True,
            True,
        ],
    )
    transformer, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result

    if accelerator.is_main_process:
        accelerator.init_trackers("lora_training")
        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.
        accelerator.log({"configuration": f"```json\n{json.dumps(config.dict(), indent=2, default=str)}\n```\n"})

    checkpoint_tracker = CheckpointTracker(
        base_dir=ckpt_dir,
        prefix="checkpoint",
        max_checkpoints=config.max_checkpoints,
        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
    )

    # Train!
    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num batches = {len(data_loader)}")
    logger.info(f"  Instantaneous batch size per device = {config.train_batch_size}")
    logger.info(f"  Gradient accumulation steps = {config.gradient_accumulation_steps}")
    logger.info(f"  Parallel processes = {accelerator.num_processes}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Total optimization steps = {num_train_steps}")
    logger.info(f"  Total epochs = {num_train_epochs}")

    global_step = 0
    first_epoch = 0
    completed_epochs = 0

    progress_bar = tqdm(
        range(global_step, num_train_steps),
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            _save_flux_lora_checkpoint(
                epoch=num_completed_epochs,
                step=num_completed_steps,
                transformer=transformer if config.train_transformer else None,
                text_encoder_1=text_encoder_1 if config.train_text_encoder else None,
                # text_encoder_2 never receives LoRA layers (see above), so it has no LoRA weights to save.
                text_encoder_2=None,
                logger=logger,
                checkpoint_tracker=checkpoint_tracker,
                lora_checkpoint_format=config.lora_checkpoint_format,
                callbacks=callbacks,
            )
        accelerator.wait_for_everyone()

    def validate(num_completed_epochs: int, num_completed_steps: int):
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            generate_validation_images_flux(
                epoch=num_completed_epochs,
                step=num_completed_steps,
                out_dir=out_dir,
                accelerator=accelerator,
                vae=vae,
                text_encoder_1=text_encoder_1,
                text_encoder_2=text_encoder_2,
                tokenizer_1=tokenizer_1,
                tokenizer_2=tokenizer_2,
                noise_scheduler=noise_scheduler,
                transformer=transformer,
                config=config,
                logger=logger,
                callbacks=callbacks,
            )
        accelerator.wait_for_everyone()

    for epoch in range(first_epoch, num_train_epochs):
        train_loss = 0.0
        for data_batch_idx, data_batch in enumerate(data_loader):
            with accelerator.accumulate(transformer, text_encoder_1, text_encoder_2):
                loss = train_forward(
                    config=config,
                    data_batch=data_batch,
                    vae=vae,
                    noise_scheduler=noise_scheduler,
                    tokenizer_1=tokenizer_1,
                    tokenizer_2=tokenizer_2,
                    text_encoder_1=text_encoder_1,
                    text_encoder_2=text_encoder_2,
                    transformer=transformer,
                    weight_dtype=weight_dtype,
                    min_snr_gamma=config.min_snr_gamma,
                )

                # Gather the losses across all processes for logging (if we use distributed training).
                # TODO(ryand): Test that this works properly with distributed training.
                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()
                train_loss += avg_loss.item() / config.gradient_accumulation_steps

                # Backpropagate.
                accelerator.backward(loss)
                if accelerator.sync_gradients and config.max_grad_norm is not None:
                    params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models])
                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Checks if the accelerator has performed an optimization step behind the scenes.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1
                log = {"train_loss": train_loss}

                lrs = lr_scheduler.get_last_lr()
                if config.train_transformer:
                    # When training the transformer, it will always be the first parameter group.
                    log["lr/transformer"] = float(lrs[0])
                    if config.optimizer.optimizer_type == "Prodigy":
                        log["lr/d*lr/transformer"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
                if config.train_text_encoder:
                    # When training the text encoder, it will always be the last parameter group.
                    log["lr/text_encoder"] = float(lrs[-1])
                    if config.optimizer.optimizer_type == "Prodigy":
                        log["lr/d*lr/text_encoder"] = optimizer.param_groups[-1]["d"] * optimizer.param_groups[-1]["lr"]

                accelerator.log(log, step=global_step)
                train_loss = 0.0

                # global_step represents the *number of completed steps* at this point.
                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:
                    save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

                if (
                    config.validate_every_n_steps is not None
                    and global_step % config.validate_every_n_steps == 0
                    and len(config.validation_prompts) > 0
                ):
                    validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            if global_step >= num_train_steps:
                break

        # Save a checkpoint every n epochs.
        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:
            save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

        # Generate validation images every n epochs.
        if (
            config.validate_every_n_epochs is not None
            and completed_epochs % config.validate_every_n_epochs == 0
            and len(config.validation_prompts) > 0
        ):
            validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

    accelerator.end_training()
================================================
FILE: src/invoke_training/pipelines/invoke_train.py
================================================
import os
from invoke_training.config.pipeline_config import PipelineConfig
from invoke_training.pipelines._experimental.sd_dpo_lora.train import train as train_sd_ddpo_lora
from invoke_training.pipelines.callbacks import PipelineCallbacks
from invoke_training.pipelines.flux.lora.train import train as train_flux_lora
from invoke_training.pipelines.stable_diffusion.lora.train import train as train_sd_lora
from invoke_training.pipelines.stable_diffusion.textual_inversion.train import train as train_sd_ti
from invoke_training.pipelines.stable_diffusion_xl.finetune.train import train as train_sdxl_finetune
from invoke_training.pipelines.stable_diffusion_xl.lora.train import train as train_sdxl_lora
from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.train import (
train as train_sdxl_lora_and_ti,
)
from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.train import train as train_sdxl_ti
def train(config: PipelineConfig, callbacks: list[PipelineCallbacks] | None = None):
    """Dispatch `config` to the training pipeline selected by `config.type`.

    This is the main entry point for all training pipelines.

    Args:
        config: The pipeline config. Its `type` field selects the pipeline.
        callbacks: Optional callbacks, forwarded unchanged to the selected pipeline.

    Raises:
        AssertionError: If any callback is not a `PipelineCallbacks` instance (when assertions are enabled).
        ValueError: If `config.type` does not name a known pipeline.
    """
    # Validate callback types before doing any work, so a bad callback fails immediately rather than mid-training.
    for callback in callbacks or []:
        assert isinstance(callback, PipelineCallbacks)

    pipeline_type = config.type
    trainers_by_type = {
        "FLUX_LORA": train_flux_lora,
        "SD_LORA": train_sd_lora,
        "SDXL_LORA": train_sdxl_lora,
        "SD_TEXTUAL_INVERSION": train_sd_ti,
        "SDXL_TEXTUAL_INVERSION": train_sdxl_ti,
        "SDXL_LORA_AND_TEXTUAL_INVERSION": train_sdxl_lora_and_ti,
        "SDXL_FINETUNE": train_sdxl_finetune,
        "SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA": train_sd_ddpo_lora,
    }
    trainer = trainers_by_type.get(pipeline_type)
    if trainer is None:
        raise ValueError(f"Unexpected pipeline type: '{config.type}'.")

    # Per-pipeline setup that must happen right before training starts.
    if pipeline_type == "FLUX_LORA":
        # Disable tokenizer parallelism to avoid issues with tokenization
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
    elif pipeline_type == "SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA":
        print(f"Running EXPERIMENTAL pipeline: '{config.type}'.")

    trainer(config, callbacks)
================================================
FILE: src/invoke_training/pipelines/stable_diffusion/__init__.py
================================================
================================================
FILE: src/invoke_training/pipelines/stable_diffusion/lora/__init__.py
================================================
================================================
FILE: src/invoke_training/pipelines/stable_diffusion/lora/config.py
================================================
from typing import Annotated, Literal, Union
from pydantic import Field, model_validator
from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (
TEXT_ENCODER_TARGET_MODULES,
UNET_TARGET_MODULES,
)
from invoke_training.config.base_pipeline_config import BasePipelineConfig
from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig
from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig
class SdLoraConfig(BasePipelineConfig):
type: Literal["SD_LORA"] = "SD_LORA"
model: str = "runwayml/stable-diffusion-v1-5"
"""Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint
file. (E.g. 'runwayml/stable-diffusion-v1-5', '/path/to/realisticVisionV51_v51VAE.safetensors', etc. )
"""
hf_variant: str | None = "fp16"
"""The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.
"""
# Note: Pydantic handles mutable default values well:
# https://docs.pydantic.dev/latest/concepts/models/#fields-with-non-hashable-default-values
base_embeddings: dict[str, str] = {}
"""A mapping of embedding tokens to trained embedding file paths. These embeddings will be applied to the base model
before training.
Example:
```
base_embeddings = {
"bruce_the_gnome": "/path/to/bruce_the_gnome.safetensors",
}
```
Consider also adding the embedding tokens to the `data_loader.caption_prefix` if they are not already present in the
dataset captions.
Note that the embeddings themselves are not fine-tuned further, but they will impact the LoRA model training if they
are referenced in the dataset captions. The list of embeddings provided here should be the same list used at
generation time with the resultant LoRA model.
"""
lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"
"""The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`."""
train_unet: bool = True
"""Whether to add LoRA layers to the UNet model and train it.
"""
train_text_encoder: bool = True
"""Whether to add LoRA layers to the text encoder and train it.
"""
optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()
text_encoder_learning_rate: float | None = None
"""The learning rate to use for the text encoder model. If set, this overrides the optimizer's default learning
rate. Set to null or 0 to use the optimizer's default learning rate.
"""
unet_learning_rate: float | None = None
"""The learning rate to use for the UNet model. If set, this overrides the optimizer's default learning rate.
Set to null or 0 to use the optimizer's default learning rate.
"""
lr_scheduler: Literal[
"linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"
] = "constant"
lr_warmup_steps: int = 0
"""The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.
See lr_scheduler.
"""
min_snr_gamma: float | None = 5.0
"""Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy
improves the speed of training convergence by adjusting the weight of each sample.
`min_snr_gamma` acts like an an upper bound on the weight of samples with low noise levels.
If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`.
"""
lora_rank_dim: int = 4
"""The rank dimension to use for the LoRA layers. Increasing the rank dimension increases the model's expressivity,
but also increases the size of the generated LoRA model.
"""
# The default list of target modules is based on
# https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/stable_diffusion/train_dreambooth.py#L49C1-L65C87
unet_lora_target_modules: list[str] = UNET_TARGET_MODULES
"""The list of target modules to apply LoRA layers to in the UNet model. The default list will produce a highly
expressive LoRA model.
For a smaller and less expressive LoRA model, the following list is recommended:
```python
unet_lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
```
The list of target modules is passed to Hugging Face's PEFT library. See
[the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for
details.
"""
text_encoder_lora_target_modules: list[str] = TEXT_ENCODER_TARGET_MODULES
"""The list of target modules to apply LoRA layers to in the text encoder models. The default list will produce a
highly expressive LoRA model.
For a smaller and less expressive LoRA model, the following list is recommended:
```python
text_encoder_lora_target_modules = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"]
```
The list of target modules is passed to Hugging Face's PEFT library. See
[the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for
details.
"""
cache_text_encoder_outputs: bool = False
"""If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and
the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the
text encoders in VRAM), and speeds up training (don't have to run the text encoders for each training example).
This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being
applied.
"""
cache_vae_outputs: bool = False
"""If True, the VAE will be applied to all of the images in the dataset before starting training and the results
will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and
speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all
non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False).
"""
enable_cpu_offload_during_validation: bool = False
"""If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation
images. This reduces VRAM requirements at the cost of slower generation of validation images.
"""
gradient_accumulation_steps: int = 1
"""The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face
Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.
"""
weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16"
"""All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and
result in faster training, but are more prone to issues with numerical stability.
Recommendations:
- `"float32"`: Use this mode if you have plenty of VRAM available.
- `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
- `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.
See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion.lora.config.SdLoraConfig.mixed_precision].
""" # noqa: E501
mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
"""The mixed precision mode to use.
If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and
trainable parameters are kept in float32 precision to avoid issues with numerical stability.
This value is passed to Hugging Face Accelerate. See
[`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)
for more details.
""" # noqa: E501
xformers: bool = False
"""If true, use xformers for more efficient attention blocks.
"""
gradient_checkpointing: bool = False
"""Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling
gradient checkpointing slows down training by ~20%.
"""
max_checkpoints: int | None = None
"""The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this
limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.
"""
prediction_type: Literal["epsilon", "v_prediction"] | None = None
"""The prediction_type that will be used for training. Choose between 'epsilon' or 'v_prediction' or leave 'None'.
If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used.
"""
max_grad_norm: float | None = None
"""Max gradient norm for clipping. Set to null or 0 for no clipping.
"""
validation_prompts: list[str] = []
"""A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.
See also 'validate_every_n_epochs'.
"""
negative_validation_prompts: list[str] | None = None
"""A list of negative prompts that will be applied when generating validation images. If set, this list should have
the same length as 'validation_prompts'.
"""
num_validation_images_per_prompt: int = 4
"""The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can
become quite slow if this number is too large.
"""
train_batch_size: int = 4
"""The training batch size.
"""
use_masks: bool = False
"""If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this
feature to be used.
"""
data_loader: Annotated[
Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig], Field(discriminator="type")
]
# Cross-field validation run by pydantic after the individual fields have been validated (mode="after").
# The `negative_validation_prompts` field docs require that, when set, it has the same length as
# `validation_prompts`; this enforces that rule.
# Raises ValueError on a length mismatch; a `None` value for `negative_validation_prompts` is always accepted.
@model_validator(mode="after")
def check_validation_prompts(self):
if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(
self.validation_prompts
):
raise ValueError(
f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of "
f"negative_validation_prompts ({len(self.negative_validation_prompts)})."
)
return self  # mode="after" validators must return the (validated) model instance.
================================================
FILE: src/invoke_training/pipelines/stable_diffusion/lora/train.py
================================================
import itertools
import json
import logging
import math
import os
import tempfile
import time
from pathlib import Path
from typing import Literal, Optional, Union
import peft
import torch
import torch.utils.data
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from invoke_training._shared.accelerator.accelerator_utils import (
get_dtype_from_str,
initialize_accelerator,
initialize_logging,
)
from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker
from invoke_training._shared.data.data_loaders.dreambooth_sd_dataloader import build_dreambooth_sd_dataloader
from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import build_image_caption_sd_dataloader
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer
from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (
save_sd_kohya_checkpoint,
save_sd_peft_checkpoint,
)
from invoke_training._shared.stable_diffusion.min_snr_weighting import compute_snr
from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sd
from invoke_training._shared.stable_diffusion.tokenize_captions import tokenize_captions
from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sd
from invoke_training._shared.utils.import_xformers import import_xformers
from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig
from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint
from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig
def _save_sd_lora_checkpoint(
    epoch: int,
    step: int,
    unet: peft.PeftModel | None,
    text_encoder: peft.PeftModel | None,
    logger: logging.Logger,
    checkpoint_tracker: CheckpointTracker,
    lora_checkpoint_format: Literal["invoke_peft", "kohya"],
    callbacks: list[PipelineCallbacks] | None,
):
    """Write a LoRA checkpoint for the UNet and/or text encoder, then notify any callbacks.

    Before writing, the checkpoint tracker is asked to prune old checkpoints so that there is room for one more.

    Args:
        epoch (int): Epoch number used to name the checkpoint and reported to callbacks.
        step (int): Step number used to name the checkpoint and reported to callbacks.
        unet (peft.PeftModel | None): The LoRA-wrapped UNet, or None if it is not being saved.
        text_encoder (peft.PeftModel | None): The LoRA-wrapped text encoder, or None if it is not being saved.
        logger (logging.Logger): Logger used to report pruning.
        checkpoint_tracker (CheckpointTracker): Decides the save path and prunes old checkpoints.
        lora_checkpoint_format: `"invoke_peft"` or `"kohya"`.
        callbacks (list[PipelineCallbacks] | None): Each one receives `on_save_checkpoint(...)`.

    Raises:
        ValueError: If `lora_checkpoint_format` is not a supported format.
    """
    # Make room for the new checkpoint before asking for its path.
    pruned_count = checkpoint_tracker.prune(1)
    if pruned_count > 0:
        logger.info(f"Pruned {pruned_count} checkpoint(s).")

    checkpoint_path = checkpoint_tracker.get_path(epoch=epoch, step=step)

    # Select the serializer and the model type reported to callbacks.
    if lora_checkpoint_format == "invoke_peft":
        save_fn, model_type = save_sd_peft_checkpoint, ModelType.SD1_LORA_PEFT
    elif lora_checkpoint_format == "kohya":
        save_fn, model_type = save_sd_kohya_checkpoint, ModelType.SD1_LORA_KOHYA
    else:
        raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")
    save_fn(Path(checkpoint_path), unet=unet, text_encoder=text_encoder)

    if callbacks is None:
        return
    for callback in callbacks:
        callback.on_save_checkpoint(
            TrainingCheckpoint(
                models=[ModelCheckpoint(file_path=checkpoint_path, model_type=model_type)],
                epoch=epoch,
                step=step,
            )
        )
def _build_data_loader(
    data_loader_config: Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig],
    batch_size: int,
    use_masks: bool = False,
    text_encoder_output_cache_dir: Optional[str] = None,
    vae_output_cache_dir: Optional[str] = None,
    shuffle: bool = True,
    sequential_batching: bool = False,
) -> DataLoader:
    """Build the DataLoader that matches the `type` of `data_loader_config`.

    Args:
        data_loader_config: An image-caption or DreamBooth data loader config; dispatch is on its `type` field.
        batch_size (int): Number of examples per batch.
        use_masks (bool): Whether to load image masks. Only supported for the image-caption loader.
        text_encoder_output_cache_dir (Optional[str]): Optional directory of cached text encoder outputs.
        vae_output_cache_dir (Optional[str]): Optional directory of cached VAE outputs.
        shuffle (bool): Whether to shuffle the data.
        sequential_batching (bool): Forwarded to the DreamBooth loader only; not used by the image-caption loader.

    Returns:
        DataLoader: The constructed data loader.

    Raises:
        NotImplementedError: If `use_masks` is True with a DreamBooth config.
        ValueError: If the config type is not supported.
    """
    loader_type = data_loader_config.type
    # Both loaders expose the cached text encoder output under the same batch key.
    cache_field_map = {"text_encoder_output": "text_encoder_output"}

    if loader_type == "DREAMBOOTH_SD_DATA_LOADER":
        if use_masks:
            raise NotImplementedError("Masks are not yet supported for DreamBooth data loaders.")
        return build_dreambooth_sd_dataloader(
            config=data_loader_config,
            batch_size=batch_size,
            text_encoder_output_cache_dir=text_encoder_output_cache_dir,
            text_encoder_cache_field_to_output_field=cache_field_map,
            vae_output_cache_dir=vae_output_cache_dir,
            shuffle=shuffle,
            sequential_batching=sequential_batching,
        )

    if loader_type != "IMAGE_CAPTION_SD_DATA_LOADER":
        raise ValueError(f"Unsupported data loader config type: '{data_loader_config.type}'.")

    return build_image_caption_sd_dataloader(
        config=data_loader_config,
        batch_size=batch_size,
        use_masks=use_masks,
        text_encoder_output_cache_dir=text_encoder_output_cache_dir,
        text_encoder_cache_field_to_output_field=cache_field_map,
        vae_output_cache_dir=vae_output_cache_dir,
        shuffle=shuffle,
    )
def cache_text_encoder_outputs(
    cache_dir: str, config: SdLoraConfig, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel
):
    """Run the text encoder on all captions in the dataset and cache the results to disk.

    Each example is saved under its `id` as `{"text_encoder_output": <tensor>}`. The dataset is iterated without
    shuffling and with sequential batching.

    Args:
        cache_dir (str): The directory where the results will be cached.
        config (SdLoraConfig): Training config. `data_loader` and `train_batch_size` are used to build the loader.
        tokenizer (CLIPTokenizer): The tokenizer.
        text_encoder (CLIPTextModel): The text_encoder.
    """
    data_loader = _build_data_loader(
        data_loader_config=config.data_loader,
        batch_size=config.train_batch_size,
        shuffle=False,
        sequential_batching=True,
    )

    cache = TensorDiskCache(cache_dir)

    # This is a one-off, inference-only pre-computation: disable autograd so that no computation graph is built (and
    # held in memory) for the text encoder forward pass, and the cached tensors are not attached to a graph.
    with torch.no_grad():
        for data_batch in tqdm(data_loader):
            caption_token_ids = tokenize_captions(tokenizer, data_batch["caption"]).to(text_encoder.device)
            text_encoder_output_batch = text_encoder(caption_token_ids)[0]
            # Split batch before caching.
            for i, example_id in enumerate(data_batch["id"]):
                cache.save(example_id, {"text_encoder_output": text_encoder_output_batch[i]})
def cache_vae_outputs(cache_dir: str, data_loader: DataLoader, vae: AutoencoderKL):
    """Run the VAE on all images in the dataset and cache the results to disk.

    For each example, saves (under its `id`) the scaled latent sample as `"vae_output"`, plus `"original_size_hw"`,
    `"crop_top_left_yx"`, and `"mask"` when the batch contains masks.

    Args:
        cache_dir (str): The directory where the results will be cached.
        data_loader (DataLoader): Yields batches with `"id"`, `"image"`, `"original_size_hw"`, `"crop_top_left_yx"`
            and optionally `"mask"`.
        vae (AutoencoderKL): The VAE used to encode the images.
    """
    cache = TensorDiskCache(cache_dir)

    # Only the encoded outputs are needed (nothing is back-propagated through this function), so disable autograd to
    # avoid building and holding a computation graph for every batch.
    with torch.no_grad():
        for data_batch in tqdm(data_loader):
            latents = vae.encode(data_batch["image"].to(device=vae.device, dtype=vae.dtype)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            # Split batch before caching.
            for i, example_id in enumerate(data_batch["id"]):
                data = {
                    "vae_output": latents[i],
                    "original_size_hw": data_batch["original_size_hw"][i],
                    "crop_top_left_yx": data_batch["crop_top_left_yx"][i],
                }
                if "mask" in data_batch:
                    data["mask"] = data_batch["mask"][i]
                cache.save(example_id, data)
def train_forward(  # noqa: C901
    config: SdLoraConfig,
    data_batch: dict,
    vae: AutoencoderKL,
    noise_scheduler: DDPMScheduler,
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    weight_dtype: torch.dtype,
    use_masks: bool = False,
    min_snr_gamma: float | None = None,
) -> torch.Tensor:
    """Run the forward training pass for a single data_batch.

    Args:
        config (SdLoraConfig): Training config. Only `prediction_type` is read; if set, it is written into
            `noise_scheduler`'s config (side effect).
        data_batch (dict): A training batch. Uses `"vae_output"` if present (else encodes `"image"`),
            `"text_encoder_output"` if present (else encodes `"caption"`), `"mask"` when `use_masks` is True, and
            `"loss_weight"` if present.
        vae (AutoencoderKL): Used only when the batch has no cached `"vae_output"`.
        noise_scheduler (DDPMScheduler): Used to sample noise levels, noise the latents and build the target.
        tokenizer (CLIPTokenizer): Used only when the batch has no cached `"text_encoder_output"`.
        text_encoder (CLIPTextModel): Used only when the batch has no cached `"text_encoder_output"`.
        unet (UNet2DConditionModel): The denoising model being trained.
        weight_dtype (torch.dtype): Dtype that images and freshly computed text embeddings are cast to.
        use_masks (bool): If True, the element-wise loss is multiplied by the batch mask, resized to latent size.
        min_snr_gamma (float | None): If not None, apply Min-SNR loss weighting with this gamma.

    Returns:
        torch.Tensor: Loss (scalar, mean over the batch).

    Raises:
        ValueError: If the scheduler's prediction type is neither 'epsilon' nor 'v_prediction'.
    """
    # Convert images to latent space.
    # The VAE output may have been cached and included in the data_batch. If not, we calculate it here.
    latents = data_batch.get("vae_output", None)
    if latents is None:
        # The VAE is only used as a fixed encoder here; no gradients need to flow through it.
        with torch.no_grad():
            latents = vae.encode(data_batch["image"].to(dtype=weight_dtype)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

    # Sample noise that we'll add to the latents.
    noise = torch.randn_like(latents)

    batch_size = latents.shape[0]
    # Sample a random timestep for each image.
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (batch_size,),
        device=latents.device,
    )
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep (this is the forward
    # diffusion process).
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Get the text embedding for conditioning.
    # The text_encoder_output may have been cached and included in the data_batch. If not, we calculate it here.
    encoder_hidden_states = data_batch.get("text_encoder_output", None)
    if encoder_hidden_states is None:
        caption_token_ids = tokenize_captions(tokenizer, data_batch["caption"]).to(text_encoder.device)
        encoder_hidden_states = text_encoder(caption_token_ids)[0].to(dtype=weight_dtype)

    # Get the target for loss depending on the prediction type.
    if config.prediction_type is not None:
        # Set the prediction_type of scheduler if it's defined in config.
        noise_scheduler.register_to_config(prediction_type=config.prediction_type)
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    # Predict the noise residual.
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    min_snr_weights = None
    if min_snr_gamma is not None:
        # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
        # The paper defines the weight in x_0-space as min(SNR(t), gamma). Since we predict epsilon or v instead of
        # x_0, that weight must be divided by the factor relating the two losses (Section 4.2 of the same paper).
        snr = compute_snr(noise_scheduler, timesteps)
        clamped_snr = torch.clamp(snr, max=min_snr_gamma)

        if noise_scheduler.config.prediction_type == "epsilon":
            # ||eps - eps_hat||^2 = SNR(t) * ||x_0 - x_0_hat||^2  =>  w_t = min(SNR(t), gamma) / SNR(t)
            min_snr_weights = clamped_snr / snr
        elif noise_scheduler.config.prediction_type == "v_prediction":
            # ||v - v_hat||^2 = (SNR(t) + 1) * ||x_0 - x_0_hat||^2  =>  w_t = min(SNR(t), gamma) / (SNR(t) + 1)
            min_snr_weights = clamped_snr / (snr + 1)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none")

    if use_masks:
        # TODO(ryand): As a future performance optimization, we may want to do this resizing in the dataloader.
        mask = data_batch["mask"].to(dtype=loss.dtype, device=loss.device)
        _, _, latent_h, latent_w = loss.shape
        mask = torch.nn.functional.interpolate(mask, size=(latent_h, latent_w), mode="nearest")
        loss = loss * mask

    # Mean-reduce the loss along all dimensions except for the batch dimension.
    loss = loss.mean(dim=list(range(1, len(loss.shape))))

    # Apply min_snr_weights.
    if min_snr_weights is not None:
        loss = loss * min_snr_weights

    # Apply per-example loss weights.
    if "loss_weight" in data_batch:
        loss = loss * data_batch["loss_weight"]

    return loss.mean()
def train(config: SdLoraConfig, callbacks: list[PipelineCallbacks] | None = None):  # noqa: C901
    """Train LoRA layers for a Stable Diffusion UNet and/or text encoder.

    Creates a timestamped output directory under `config.base_output_dir` (containing `config.json` and a
    `checkpoints/` subdirectory), loads the base models, optionally pre-computes and caches text encoder / VAE outputs,
    injects LoRA layers into the models selected by `config.train_unet` / `config.train_text_encoder`, and runs the
    training loop. Checkpoints and validation images are produced at the step or epoch intervals set in `config`.

    Args:
        config (SdLoraConfig): Training config.
        callbacks (list[PipelineCallbacks] | None): Optional callbacks, forwarded to checkpoint saving and validation
            image generation.

    Raises:
        ValueError: If `cache_text_encoder_outputs` and `train_text_encoder` are both True, or if `cache_vae_outputs`
            is combined with `random_flip=True` or `center_crop=False`.
    """
    # Give a clear error message if an unsupported base model was chosen.
    # TODO(ryan): Update this check to work with single-file SD checkpoints.
    # check_base_model_version(
    #     {BaseModelVersionEnum.STABLE_DIFFUSION_V1, BaseModelVersionEnum.STABLE_DIFFUSION_V2},
    #     config.model,
    #     local_files_only=False,
    # )

    # Create a timestamped directory for all outputs.
    out_dir = os.path.join(config.base_output_dir, f"{time.time()}")
    ckpt_dir = os.path.join(out_dir, "checkpoints")
    os.makedirs(ckpt_dir)

    accelerator = initialize_accelerator(
        out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to
    )
    logger = initialize_logging(os.path.basename(__file__), accelerator)

    # Set the accelerate seed.
    if config.seed is not None:
        set_seed(config.seed)

    # Log the accelerator configuration from every process to help with debugging.
    logger.info(accelerator.state, main_process_only=False)

    # Pydantic v2: `model_dump()` replaces the deprecated `dict()`.
    config_dict = config.model_dump()

    logger.info("Starting LoRA Training.")
    logger.info(f"Configuration:\n{json.dumps(config_dict, indent=2, default=str)}")
    logger.info(f"Output dir: '{out_dir}'")

    # Write the configuration to disk.
    with open(os.path.join(out_dir, "config.json"), "w") as f:
        json.dump(config_dict, f, indent=2, default=str)

    weight_dtype = get_dtype_from_str(config.weight_dtype)

    logger.info("Loading models.")
    tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd(
        logger=logger,
        model_name_or_path=config.model,
        hf_variant=config.hf_variant,
        base_embeddings=config.base_embeddings,
        dtype=weight_dtype,
    )

    if config.xformers:
        import_xformers()

        # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers
        # will fail because Q, K, V have different dtypes.
        unet.enable_xformers_memory_efficient_attention()
        vae.enable_xformers_memory_efficient_attention()

    # Prepare text encoder output cache.
    text_encoder_output_cache_dir_name = None
    if config.cache_text_encoder_outputs:
        # TODO(ryand): Think about how to better check if it is safe to cache the text encoder outputs. Currently, there
        # are a number of configurations that would cause variation in the text encoder outputs and should not be used
        # with caching.
        if config.train_text_encoder:
            raise ValueError("'cache_text_encoder_outputs' and 'train_text_encoder' cannot both be True.")

        # We use a temporary directory for the cache. The directory will automatically be cleaned up when
        # tmp_text_encoder_output_cache_dir is destroyed.
        # NOTE(review): every process creates its own TemporaryDirectory (with a different name), but only the local
        # main process populates it. With more than one process, the others may read from an empty directory -
        # confirm before relying on caching in distributed runs.
        tmp_text_encoder_output_cache_dir = tempfile.TemporaryDirectory()
        text_encoder_output_cache_dir_name = tmp_text_encoder_output_cache_dir.name
        if accelerator.is_local_main_process:
            # Only the main process should populate the cache.
            logger.info(f"Generating text encoder output cache ('{text_encoder_output_cache_dir_name}').")
            text_encoder.to(accelerator.device, dtype=weight_dtype)
            cache_text_encoder_outputs(text_encoder_output_cache_dir_name, config, tokenizer, text_encoder)
        # Move the text_encoder back to the CPU, because it is not needed for training.
        text_encoder.to("cpu")
        accelerator.wait_for_everyone()
    else:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # Prepare VAE output cache.
    vae_output_cache_dir_name = None
    if config.cache_vae_outputs:
        if config.data_loader.random_flip:
            raise ValueError("'cache_vae_outputs' cannot be True if 'random_flip' is True.")
        if not config.data_loader.center_crop:
            raise ValueError("'cache_vae_outputs' cannot be True if 'center_crop' is False.")

        # We use a temporary directory for the cache. The directory will automatically be cleaned up when
        # tmp_vae_output_cache_dir is destroyed.
        # NOTE(review): same per-process temporary directory caveat as the text encoder cache above.
        tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()
        vae_output_cache_dir_name = tmp_vae_output_cache_dir.name
        if accelerator.is_local_main_process:
            # Only the main process should populate the cache.
            logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').")
            vae.to(accelerator.device, dtype=weight_dtype)
            data_loader = _build_data_loader(
                data_loader_config=config.data_loader,
                batch_size=config.train_batch_size,
                use_masks=config.use_masks,
                shuffle=False,
                sequential_batching=True,
            )
            cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)
        # Move the VAE back to the CPU, because it is not needed for training.
        vae.to("cpu")
        accelerator.wait_for_everyone()
    else:
        vae.to(accelerator.device, dtype=weight_dtype)

    unet.to(accelerator.device, dtype=weight_dtype)

    # Add LoRA layers to the models being trained.
    trainable_param_groups = []
    all_trainable_models: list[peft.PeftModel] = []

    def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = None) -> peft.PeftModel:
        """Wrap `model` with LoRA layers and register its trainable params (with optional per-group lr)."""
        peft_model = peft.get_peft_model(model, lora_config)
        peft_model.print_trainable_parameters()

        # Populate `trainable_param_groups`, to be passed to the optimizer.
        param_group = {"params": list(filter(lambda p: p.requires_grad, peft_model.parameters()))}
        if lr is not None:
            param_group["lr"] = lr
        trainable_param_groups.append(param_group)

        # Populate all_trainable_models.
        all_trainable_models.append(peft_model)

        peft_model.train()

        return peft_model

    # Add LoRA layers to the model.
    if config.train_unet:
        unet_lora_config = peft.LoraConfig(
            r=config.lora_rank_dim,
            # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred?
            lora_alpha=1.0,
            target_modules=config.unet_lora_target_modules,
        )
        unet = inject_lora_layers(unet, unet_lora_config, lr=config.unet_learning_rate)

    if config.train_text_encoder:
        text_encoder_lora_config = peft.LoraConfig(
            r=config.lora_rank_dim,
            lora_alpha=1.0,
            # init_lora_weights="gaussian",
            target_modules=config.text_encoder_lora_target_modules,
        )
        text_encoder = inject_lora_layers(text_encoder, text_encoder_lora_config, lr=config.text_encoder_learning_rate)

    # If mixed_precision is enabled, cast all trainable params to float32.
    if config.mixed_precision != "no":
        for trainable_model in all_trainable_models:
            for param in trainable_model.parameters():
                if param.requires_grad:
                    param.data = param.to(torch.float32)

    if config.gradient_checkpointing:
        # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.
        unet.enable_gradient_checkpointing()
        # unet must be in train() mode for gradient checkpointing to take effect.
        # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does
        # not change its forward behavior.
        unet.train()
        if config.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

            # The text encoder must be in train() mode for gradient checkpointing to take effect. This should
            # already be the case, since we are training the text_encoder, but we do it explicitly to make it clear
            # that this is required.
            # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text
            # encoders in train mode does not change their forward behavior.
            text_encoder.train()

            # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder
            # LoRA weights would have 0 gradients, and so would not get trained. Note that the set of
            # trainable_param_groups has already been populated - the embeddings will not be trained.
            text_encoder.text_model.embeddings.requires_grad_(True)

    optimizer = initialize_optimizer(config.optimizer, trainable_param_groups)

    data_loader = _build_data_loader(
        data_loader_config=config.data_loader,
        batch_size=config.train_batch_size,
        use_masks=config.use_masks,
        text_encoder_output_cache_dir=text_encoder_output_cache_dir_name,
        vae_output_cache_dir=vae_output_cache_dir_name,
    )

    log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler)

    assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1
    assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1
    assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1

    # A "step" represents a single weight update operation (i.e. takes into account gradient accumulation steps).
    # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when
    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.
    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)
    num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch
    num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch)

    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps
    # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears
    # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process
    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),
    # so the scaling here simply reverses that behaviour.
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(
        config.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=num_train_steps * accelerator.num_processes,
    )

    prepared_result: tuple[
        UNet2DConditionModel,
        CLIPTextModel,
        torch.optim.Optimizer,
        torch.utils.data.DataLoader,
        torch.optim.lr_scheduler.LRScheduler,
    ] = accelerator.prepare(
        unet,
        text_encoder,
        optimizer,
        data_loader,
        lr_scheduler,
        # Disable automatic device placement for text_encoder if the text encoder outputs were cached.
        device_placement=[True, not config.cache_text_encoder_outputs, True, True, True],
    )
    unet, text_encoder, optimizer, data_loader, lr_scheduler = prepared_result

    if accelerator.is_main_process:
        accelerator.init_trackers("lora_training")
        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.
        accelerator.log({"configuration": f"```json\n{json.dumps(config_dict, indent=2, default=str)}\n```\n"})

    checkpoint_tracker = CheckpointTracker(
        base_dir=ckpt_dir,
        prefix="checkpoint",
        max_checkpoints=config.max_checkpoints,
        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
    )

    # Train!
    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num batches = {len(data_loader)}")
    logger.info(f"  Instantaneous batch size per device = {config.train_batch_size}")
    logger.info(f"  Gradient accumulation steps = {config.gradient_accumulation_steps}")
    logger.info(f"  Parallel processes = {accelerator.num_processes}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Total optimization steps = {num_train_steps}")
    logger.info(f"  Total epochs = {num_train_epochs}")

    global_step = 0
    first_epoch = 0
    completed_epochs = 0

    progress_bar = tqdm(
        range(global_step, num_train_steps),
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    # Gradient clipping is enabled only for a positive max_grad_norm. The config documents `null` or 0 as "no
    # clipping"; clip_grad_norm_ with max_norm=0 would instead scale every gradient to zero.
    clip_grad_norm_enabled = config.max_grad_norm is not None and config.max_grad_norm > 0

    def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):
        """Save a LoRA checkpoint from the main process, synchronizing all processes before and after."""
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            _save_sd_lora_checkpoint(
                epoch=num_completed_epochs,
                step=num_completed_steps,
                unet=accelerator.unwrap_model(unet) if config.train_unet else None,
                text_encoder=accelerator.unwrap_model(text_encoder) if config.train_text_encoder else None,
                logger=logger,
                checkpoint_tracker=checkpoint_tracker,
                lora_checkpoint_format=config.lora_checkpoint_format,
                callbacks=callbacks,
            )
        accelerator.wait_for_everyone()

    def validate(num_completed_epochs: int, num_completed_steps: int):
        """Generate validation images from the main process, synchronizing all processes before and after."""
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            generate_validation_images_sd(
                epoch=num_completed_epochs,
                step=num_completed_steps,
                out_dir=out_dir,
                accelerator=accelerator,
                vae=vae,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                noise_scheduler=noise_scheduler,
                unet=unet,
                config=config,
                logger=logger,
                callbacks=callbacks,
            )
        accelerator.wait_for_everyone()

    for epoch in range(first_epoch, num_train_epochs):
        train_loss = 0.0
        for data_batch_idx, data_batch in enumerate(data_loader):
            with accelerator.accumulate(unet, text_encoder):
                loss = train_forward(
                    config=config,
                    data_batch=data_batch,
                    vae=vae,
                    noise_scheduler=noise_scheduler,
                    tokenizer=tokenizer,
                    text_encoder=text_encoder,
                    unet=unet,
                    weight_dtype=weight_dtype,
                    use_masks=config.use_masks,
                    min_snr_gamma=config.min_snr_gamma,
                )

                # Gather the losses across all processes for logging (if we use distributed training).
                # TODO(ryand): Test that this works properly with distributed training.
                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()
                train_loss += avg_loss.item() / config.gradient_accumulation_steps

                # Backpropagate.
                accelerator.backward(loss)
                if accelerator.sync_gradients and clip_grad_norm_enabled:
                    params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models])
                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Checks if the accelerator has performed an optimization step behind the scenes.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1

                log = {"train_loss": train_loss}

                lrs = lr_scheduler.get_last_lr()
                if config.train_unet:
                    # When training the UNet, it will always be the first parameter group.
                    log["lr/unet"] = float(lrs[0])
                    if config.optimizer.optimizer_type == "Prodigy":
                        log["lr/d*lr/unet"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
                if config.train_text_encoder:
                    # When training the text encoder, it will always be the last parameter group.
                    log["lr/text_encoder"] = float(lrs[-1])
                    if config.optimizer.optimizer_type == "Prodigy":
                        log["lr/d*lr/text_encoder"] = optimizer.param_groups[-1]["d"] * optimizer.param_groups[-1]["lr"]

                accelerator.log(log, step=global_step)
                train_loss = 0.0

                # global_step represents the *number of completed steps* at this point.
                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:
                    save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

                if (
                    config.validate_every_n_steps is not None
                    and global_step % config.validate_every_n_steps == 0
                    and len(config.validation_prompts) > 0
                ):
                    validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            if global_step >= num_train_steps:
                break

        # Save a checkpoint every n epochs.
        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:
            save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

        # Generate validation images every n epochs.
        if (
            config.validate_every_n_epochs is not None
            and completed_epochs % config.validate_every_n_epochs == 0
            and len(config.validation_prompts) > 0
        ):
            validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

    accelerator.end_training()
================================================
FILE: src/invoke_training/pipelines/stable_diffusion/textual_inversion/__init__.py
================================================
================================================
FILE: src/invoke_training/pipelines/stable_diffusion/textual_inversion/config.py
================================================
from typing import Literal
from pydantic import model_validator
from invoke_training.config.base_pipeline_config import BasePipelineConfig
from invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig
from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig
class SdTextualInversionConfig(BasePipelineConfig):
type: Literal["SD_TEXTUAL_INVERSION"] = "SD_TEXTUAL_INVERSION"
"""Must be `SD_TEXTUAL_INVERSION`. This is what differentiates training pipeline types.
"""
model: str
"""Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint
file. (E.g. `"runwayml/stable-diffusion-v1-5"`, `"stabilityai/stable-diffusion-xl-base-1.0"`,
`"/path/to/local/model.safetensors"`, etc.)
The model architecture must match the training pipeline being run. For example, if running a
Textual Inversion SDXL pipeline, then `model` must refer to an SDXL model.
"""
hf_variant: str | None = "fp16"
"""The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.
"""
# Helpful discussion for understanding how this works at inference time:
# https://github.com/huggingface/diffusers/pull/3144#discussion_r1172413509
num_vectors: int = 1
"""Note: `num_vectors` can be overridden by `initial_phrase`.
The number of textual inversion embedding vectors that will be used to learn the concept.
Increasing the `num_vectors` enables the model to learn more complex concepts, but has the following drawbacks:
- greater risk of overfitting
- increased size of the resulting output file
- consumes more of the prompt capacity at inference time
Typical values for `num_vectors` are in the range [1, 16].
As a rule of thumb, `num_vectors` can be increased as the size of the dataset increases (without overfitting).
"""
placeholder_token: str
"""The special word to associate the learned embeddings with. Choose a unique token that is unlikely to already
exist in the tokenizer's vocabulary.
"""
initializer_token: str | None = None
"""Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set.
A vocabulary token to use as an initializer for the placeholder token. It should be a single word that roughly
describes the object or style that you're trying to train on. Must map to a single tokenizer token.
For example, if you are training on a dataset of images of your pet dog, a good choice would be `dog`.
"""
initial_embedding_file: str | None = None
"""Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set.
Path to an existing TI embedding that will be used to initialize the embedding being trained. The placeholder
token in the file must match the `placeholder_token` field.
Either `initializer_token` or `initial_embedding_file` should be set.
"""
initial_phrase: str | None = None
"""Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set.
A phrase that will be used to initialize the placeholder token embedding. The phrase will be tokenized, and the
corresponding embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be
inferred from the length of the tokenized phrase, so keep the phrase short. The consequences of training a large
number of embedding vectors are discussed in the `num_vectors` field documentation.
For example, if you are training on a dataset of images of pokemon, you might use `pokemon sketch white background`.
"""
optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()
lr_scheduler: Literal[
"linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"
] = "constant"
lr_warmup_steps: int = 0
"""The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.
See lr_scheduler.
"""
min_snr_gamma: float | None = 5.0
"""Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy
improves the speed of training convergence by adjusting the weight of each sample.
`min_snr_gamma` acts like an an upper bound on the weight of samples with low noise levels.
If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`.
"""
cache_vae_outputs: bool = False
"""If True, the VAE will be applied to all of the images in the dataset before starting training and the results
will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and
speeds up training (don't have to run the VAE encoding step).
This option can only be enabled if all non-deterministic image augmentations are disabled (i.e. `center_crop=True`,
`random_flip=False`, etc.).
"""
enable_cpu_offload_during_validation: bool = False
"""If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation
images. This reduces VRAM requirements at the cost of slower generation of validation images.
"""
gradient_accumulation_steps: int = 1
"""The number of gradient steps to accumulate before each weight update. This is an alternative to increasing the
`train_batch_size` when training with limited VRAM.
"""
weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16"
"""All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and
result in faster training, but are more prone to issues with numerical stability.
Recommendations:
- `"float32"`: Use this mode if you have plenty of VRAM available.
- `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
- `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.
See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion.textual_inversion.config.SdTextualInversionConfig.mixed_precision].
""" # noqa: E501
mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
"""The mixed precision mode to use.
If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and
trainable parameters are kept in float32 precision to avoid issues with numerical stability.
This value is passed to Hugging Face Accelerate. See
[`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)
for more details.
""" # noqa: E501
xformers: bool = False
"""If `True`, use xformers for more efficient attention blocks.
"""
gradient_checkpointing: bool = False
"""Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling
gradient checkpointing slows down training by ~20%.
"""
max_checkpoints: int | None = None
"""The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this
limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.
"""
prediction_type: Literal["epsilon", "v_prediction"] | None = None
"""The prediction type that will be used for training. If `None`, the prediction type will be inferred from the
scheduler.
"""
max_grad_norm: float | None = None
"""Maximum gradient norm for gradient clipping. Set to `null` or 0 for no clipping.
"""
validation_prompts: list[str] = []
"""A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.
"""
negative_validation_prompts: list[str] | None = None
"""A list of negative prompts that will be applied when generating validation images. If set, this list should have
the same length as 'validation_prompts'.
"""
num_validation_images_per_prompt: int = 4
"""The number of validation images to generate for each prompt in `validation_prompts`. Careful, validation can
become very slow if this number is too large.
"""
train_batch_size: int = 4
"""The training batch size.
"""
use_masks: bool = False
"""If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this
feature to be used.
"""
data_loader: TextualInversionSDDataLoaderConfig
"""The data configuration.
See
[`TextualInversionSDDataLoaderConfig`][invoke_training.config.data.data_loader_config.TextualInversionSDDataLoaderConfig]
for details.
"""
@model_validator(mode="after")
def check_validation_prompts(self):
    """Verify that `negative_validation_prompts`, when given, pairs one-to-one with `validation_prompts`.

    Returns:
        The (unchanged) config instance.

    Raises:
        ValueError: If `negative_validation_prompts` is set and its length differs from `validation_prompts`.
    """
    negatives = self.negative_validation_prompts
    # No negative prompts configured: nothing to cross-check.
    if negatives is None:
        return self

    num_prompts = len(self.validation_prompts)
    num_negatives = len(negatives)
    if num_negatives != num_prompts:
        raise ValueError(
            f"The number of validation_prompts ({num_prompts}) must match the number of "
            f"negative_validation_prompts ({num_negatives})."
        )
    return self
================================================
FILE: src/invoke_training/pipelines/stable_diffusion/textual_inversion/train.py
================================================
import json
import logging
import math
import os
import tempfile
import time
import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, PreTrainedTokenizer
from invoke_training._shared.accelerator.accelerator_utils import (
get_dtype_from_str,
initialize_accelerator,
initialize_logging,
)
from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker
from invoke_training._shared.checkpoints.serialization import save_state_dict
from invoke_training._shared.data.data_loaders.textual_inversion_sd_dataloader import (
build_textual_inversion_sd_dataloader,
)
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets
from invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer
from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sd
from invoke_training._shared.stable_diffusion.textual_inversion import (
initialize_placeholder_tokens_from_initial_embedding,
initialize_placeholder_tokens_from_initial_phrase,
initialize_placeholder_tokens_from_initializer_token,
restore_original_embeddings,
)
from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sd
from invoke_training._shared.utils.import_xformers import import_xformers
from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint
from invoke_training.pipelines.stable_diffusion.lora.train import cache_vae_outputs, train_forward
from invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig
def _save_ti_embeddings(
    epoch: int,
    step: int,
    text_encoder: CLIPTextModel,
    placeholder_token_ids: list[int],
    accelerator: Accelerator,
    logger: logging.Logger,
    checkpoint_tracker: CheckpointTracker,
    callbacks: list[PipelineCallbacks] | None,
):
    """Write the learned placeholder-token embeddings to a new Textual Inversion checkpoint.

    `checkpoint_tracker.prune(1)` is called first so that older checkpoints are removed as needed to respect the
    tracker's limits. The rows of the (unwrapped) text encoder's input-embedding matrix from
    `min(placeholder_token_ids)` to `max(placeholder_token_ids)` (inclusive) are detached, moved to the CPU, cast to
    float32 and saved under the key "emb_params". Each callback (if any) is then notified via `on_save_checkpoint`.
    """
    pruned_count = checkpoint_tracker.prune(1)
    if pruned_count > 0:
        logger.info(f"Pruned {pruned_count} checkpoint(s).")

    checkpoint_path = checkpoint_tracker.get_path(epoch=epoch, step=step)

    # NOTE(review): this contiguous slice assumes the placeholder token ids form an unbroken range; any other ids
    # between min and max would be saved as well.
    first_id = min(placeholder_token_ids)
    last_id = max(placeholder_token_ids)
    embedding_matrix = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
    trained_rows = embedding_matrix[first_id : last_id + 1]
    save_state_dict({"emb_params": trained_rows.detach().cpu().to(torch.float32)}, checkpoint_path)

    for callback in callbacks or []:
        callback.on_save_checkpoint(
            TrainingCheckpoint(
                models=[ModelCheckpoint(file_path=checkpoint_path, model_type=ModelType.SD1_TEXTUAL_INVERSION)],
                epoch=epoch,
                step=step,
            )
        )
def _initialize_placeholder_tokens(
    config: SdTextualInversionConfig,
    tokenizer: CLIPTokenizer,
    text_encoder: CLIPTextModel,
    logger: logging.Logger,
) -> tuple[list[str], list[int]]:
    """Prepare the tokenizer and text_encoder for TI training.

    - Add the placeholder tokens to the tokenizer.
    - Add new token embeddings to the text_encoder for each of the placeholder tokens.
    - Initialize the new token embeddings from exactly one of: an existing token (`config.initializer_token`), an
      initial TI embedding file (`config.initial_embedding_file`), or an initial phrase (`config.initial_phrase`).

    Args:
        config: The TI training config. Exactly one of `initializer_token`, `initial_embedding_file`, or
            `initial_phrase` must be set.
        tokenizer: The tokenizer that the placeholder tokens are added to.
        text_encoder: The text encoder whose input token embeddings are extended and initialized.
        logger: Logger, forwarded only to the `initializer_token` initialization path.

    Returns:
        A `(placeholder_tokens, placeholder_token_ids)` tuple, as returned by the selected initialization helper.

    Raises:
        ValueError: If not exactly one of the three initialization options is set.
    """
    init_options_error = (
        "Exactly one of 'initializer_token', 'initial_embedding_file', or 'initial_phrase' should be set."
    )

    num_init_options_set = sum(
        [
            config.initializer_token is not None,
            config.initial_embedding_file is not None,
            config.initial_phrase is not None,
        ]
    )
    if num_init_options_set != 1:
        raise ValueError(init_options_error)

    if config.initializer_token is not None:
        placeholder_tokens, placeholder_token_ids = initialize_placeholder_tokens_from_initializer_token(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            initializer_token=config.initializer_token,
            placeholder_token=config.placeholder_token,
            num_vectors=config.num_vectors,
            logger=logger,
        )
    elif config.initial_embedding_file is not None:
        placeholder_tokens, placeholder_token_ids = initialize_placeholder_tokens_from_initial_embedding(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            initial_embedding_file=config.initial_embedding_file,
            placeholder_token=config.placeholder_token,
            num_vectors=config.num_vectors,
        )
    elif config.initial_phrase is not None:
        # NOTE(review): `config.num_vectors` is not forwarded on this path; presumably the number of vectors is
        # derived from the phrase itself - confirm against the helper.
        placeholder_tokens, placeholder_token_ids = initialize_placeholder_tokens_from_initial_phrase(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            initial_phrase=config.initial_phrase,
            placeholder_token=config.placeholder_token,
        )
    else:
        # Unreachable given the check above; kept so that the returned names are always bound.
        raise ValueError(init_options_error)

    return placeholder_tokens, placeholder_token_ids
def train(config: SdTextualInversionConfig, callbacks: list[PipelineCallbacks] | None = None):  # noqa: C901
    """Run Stable Diffusion Textual Inversion training.

    A new timestamped directory is created under `config.base_output_dir`. It holds a `config.json` copy of the
    configuration and a `checkpoints/` subdirectory for embedding checkpoints, and its path is passed to validation
    image generation. Only the text encoder's input token embeddings are passed to the optimizer. After every batch,
    `restore_original_embeddings` is called with a copy of the embeddings taken before training, so that only the
    newly-added placeholder token(s) are updated.

    Args:
        config: The Textual Inversion training configuration.
        callbacks: Optional callbacks, forwarded to checkpoint saving and validation image generation.
    """
    # Create a timestamped directory for all outputs.
    out_dir = os.path.join(config.base_output_dir, f"{time.time()}")
    ckpt_dir = os.path.join(out_dir, "checkpoints")
    os.makedirs(ckpt_dir)

    accelerator = initialize_accelerator(
        out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to
    )
    logger = initialize_logging(os.path.basename(__file__), accelerator)

    # Set the accelerate seed.
    if config.seed is not None:
        set_seed(config.seed)

    # Log the accelerator configuration from every process to help with debugging.
    logger.info(accelerator.state, main_process_only=False)

    logger.info("Starting Textual Inversion Training.")
    logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}")
    logger.info(f"Output dir: '{out_dir}'")

    # Write the configuration to disk.
    with open(os.path.join(out_dir, "config.json"), "w") as f:
        json.dump(config.dict(), f, indent=2, default=str)

    weight_dtype = get_dtype_from_str(config.weight_dtype)

    logger.info("Loading models.")
    tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd(
        logger=logger, model_name_or_path=config.model, hf_variant=config.hf_variant, dtype=weight_dtype
    )

    placeholder_tokens, placeholder_token_ids = _initialize_placeholder_tokens(
        config=config, tokenizer=tokenizer, text_encoder=text_encoder, logger=logger
    )
    logger.info(f"Initialized {len(placeholder_tokens)} placeholder tokens: {placeholder_tokens}.")

    # All parameters of the VAE, UNet, and text encoder are currently frozen. Just unfreeze the token embeddings in the
    # text encoder.
    text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)

    if config.gradient_checkpointing:
        # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.
        unet.enable_gradient_checkpointing()
        # unet must be in train() mode for gradient checkpointing to take effect.
        # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does
        # not change its forward behavior.
        unet.train()

        # The text_encoder will be put in .train() mode later, so we don't need to worry about that here.
        # Note: There are some weird interactions gradient checkpointing and requires_grad_() when training a
        # text_encoder LoRA. If this code ever gets copied elsewhere, make sure to take a look at how this is handled in
        # other training pipelines.
        text_encoder.gradient_checkpointing_enable()

    if config.xformers:
        import_xformers()

        unet.enable_xformers_memory_efficient_attention()
        vae.enable_xformers_memory_efficient_attention()

    # Prepare VAE output cache.
    vae_output_cache_dir_name = None
    if config.cache_vae_outputs:
        if config.data_loader.random_flip:
            raise ValueError("'cache_vae_outputs' cannot be True if 'random_flip' is True.")
        if not config.data_loader.center_crop:
            raise ValueError("'cache_vae_outputs' cannot be True if 'center_crop' is False.")

        # We use a temporary directory for the cache. The directory will automatically be cleaned up when
        # tmp_vae_output_cache_dir is destroyed.
        tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()
        vae_output_cache_dir_name = tmp_vae_output_cache_dir.name
        if accelerator.is_local_main_process:
            # Only the main process should populate the cache.
            logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').")
            vae.to(accelerator.device, dtype=weight_dtype)
            data_loader = build_textual_inversion_sd_dataloader(
                config=config.data_loader,
                placeholder_token=config.placeholder_token,
                batch_size=config.train_batch_size,
                use_masks=config.use_masks,
                shuffle=False,
            )
            cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)
        # Move the VAE back to the CPU, because it is not needed for training.
        vae.to("cpu")
        accelerator.wait_for_everyone()
    else:
        vae.to(accelerator.device, dtype=weight_dtype)

    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    # Initialize the optimizer to only optimize the token embeddings.
    optimizer = initialize_optimizer(config.optimizer, text_encoder.get_input_embeddings().parameters())

    data_loader = build_textual_inversion_sd_dataloader(
        config=config.data_loader,
        placeholder_token=config.placeholder_token,
        batch_size=config.train_batch_size,
        use_masks=config.use_masks,
        vae_output_cache_dir=vae_output_cache_dir_name,
    )

    log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler)

    assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1
    assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1
    assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1

    # A "step" represents a single weight update operation (i.e. takes into account gradient accumulation steps).
    # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when
    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.
    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)
    num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch
    num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch)

    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps
    # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears
    # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process
    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),
    # so the scaling here simply reverses that behaviour.
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(
        config.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=num_train_steps * accelerator.num_processes,
    )

    # Prepare everything with our `accelerator`.
    # `accelerator.prepare(...)` must be called exactly once on these objects: calling it again on already-prepared
    # objects wraps them a second time (model, optimizer, dataloader, and LR scheduler would all be double-wrapped).
    prepared_result: tuple[
        CLIPTextModel, torch.optim.Optimizer, torch.utils.data.DataLoader, torch.optim.lr_scheduler.LRScheduler
    ] = accelerator.prepare(text_encoder, optimizer, data_loader, lr_scheduler)
    text_encoder, optimizer, data_loader, lr_scheduler = prepared_result

    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion_training")
        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.
        accelerator.log({"configuration": f"```json\n{json.dumps(config.dict(), indent=2, default=str)}\n```\n"})

    checkpoint_tracker = CheckpointTracker(
        base_dir=ckpt_dir,
        prefix="checkpoint",
        extension=".safetensors",
        max_checkpoints=config.max_checkpoints,
    )

    # Train!
    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num batches = {len(data_loader)}")
    logger.info(f"  Instantaneous batch size per device = {config.train_batch_size}")
    logger.info(f"  Gradient accumulation steps = {config.gradient_accumulation_steps}")
    logger.info(f"  Parallel processes = {accelerator.num_processes}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Total optimization steps = {num_train_steps}")
    logger.info(f"  Total epochs = {num_train_epochs}")

    global_step = 0
    first_epoch = 0
    completed_epochs = 0

    progress_bar = tqdm(
        range(global_step, num_train_steps),
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    # Keep original embeddings as reference.
    orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()

    def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):
        # All processes synchronize; only the main process writes the checkpoint.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            _save_ti_embeddings(
                epoch=num_completed_epochs,
                step=num_completed_steps,
                text_encoder=text_encoder,
                placeholder_token_ids=placeholder_token_ids,
                accelerator=accelerator,
                logger=logger,
                checkpoint_tracker=checkpoint_tracker,
                callbacks=callbacks,
            )
        accelerator.wait_for_everyone()

    def validate(num_completed_epochs: int, num_completed_steps: int):
        # All processes synchronize; only the main process generates validation images.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            generate_validation_images_sd(
                epoch=num_completed_epochs,
                step=num_completed_steps,
                out_dir=out_dir,
                accelerator=accelerator,
                vae=vae,
                text_encoder=text_encoder,
                tokenizer=tokenizer,
                noise_scheduler=noise_scheduler,
                unet=unet,
                config=config,
                logger=logger,
                callbacks=callbacks,
            )
        accelerator.wait_for_everyone()

    for epoch in range(first_epoch, num_train_epochs):
        text_encoder.train()

        train_loss = 0.0
        for data_batch_idx, data_batch in enumerate(data_loader):
            with accelerator.accumulate(text_encoder):
                loss = train_forward(
                    config=config,
                    data_batch=data_batch,
                    vae=vae,
                    noise_scheduler=noise_scheduler,
                    tokenizer=tokenizer,
                    text_encoder=text_encoder,
                    unet=unet,
                    weight_dtype=weight_dtype,
                    use_masks=config.use_masks,
                    min_snr_gamma=config.min_snr_gamma,
                )

                # Gather the losses across all processes for logging (if we use distributed training).
                # TODO(ryand): Test that this works properly with distributed training.
                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()
                train_loss += avg_loss.item() / config.gradient_accumulation_steps

                accelerator.backward(loss)
                if accelerator.sync_gradients and config.max_grad_norm is not None:
                    # TODO(ryand): I copied this from another pipeline. Should probably just clip the trainable params.
                    params_to_clip = text_encoder.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

                # Make sure we don't update any embedding weights besides the newly-added token(s).
                # TODO(ryand): Should we only do this if accelerator.sync_gradients?
                restore_original_embeddings(
                    tokenizer=tokenizer,
                    placeholder_token_ids=placeholder_token_ids,
                    accelerator=accelerator,
                    text_encoder=text_encoder,
                    orig_embeds_params=orig_embeds_params,
                )

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1
                log = {"train_loss": train_loss, "lr": lr_scheduler.get_last_lr()[0]}

                if config.optimizer.optimizer_type == "Prodigy":
                    # TODO(ryand): Test Prodigy logging.
                    log["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]

                accelerator.log(log, step=global_step)
                train_loss = 0.0

                # global_step represents the *number of completed steps* at this point.
                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:
                    save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

                if (
                    config.validate_every_n_steps is not None
                    and global_step % config.validate_every_n_steps == 0
                    and len(config.validation_prompts) > 0
                ):
                    validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= num_train_steps:
                break

        # Save a checkpoint every n epochs.
        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:
            save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

        # Generate validation images every n epochs.
        if (
            config.validate_every_n_epochs is not None
            and completed_epochs % config.validate_every_n_epochs == 0
            and len(config.validation_prompts) > 0
        ):
            validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

    accelerator.end_training()
================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/__init__.py
================================================
================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/finetune/__init__.py
================================================
================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/finetune/config.py
================================================
from typing import Annotated, Literal, Union
from pydantic import Field, model_validator
from invoke_training.config.base_pipeline_config import BasePipelineConfig
from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig
from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig
class SdxlFinetuneConfig(BasePipelineConfig):
    """Configuration for the `SDXL_FINETUNE` training pipeline."""

    type: Literal["SDXL_FINETUNE"] = "SDXL_FINETUNE"

    model: str = "stabilityai/stable-diffusion-xl-base-1.0"
    """Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint
    file. (E.g. 'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/JuggernautXL.safetensors', etc. )
    """

    hf_variant: str | None = "fp16"
    """The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.
    """

    save_checkpoint_format: Literal["full_diffusers", "trained_only_diffusers"] = "trained_only_diffusers"
    """The save format for the checkpoints.

    Options:

    - `full_diffusers`: Save the full model in diffusers format (including models that weren't finetuned). If you want a
    single output artifact that can be used for generation, then this is the recommended option.
    - `trained_only_diffusers`: Save only the models that were finetuned in diffusers format. For example, if only the
    UNet model was trained, then only the UNet model will be saved. This option will significantly reduce the disk space
    consumed by the saved checkpoints. If you plan to extract a LoRA from the fine-tuned model, then this is the
    recommended option.
    """

    save_dtype: Literal["float32", "float16", "bfloat16"] = "float16"
    """The dtype to use when saving the model.
    """

    optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()

    lr_scheduler: Literal[
        "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"
    ] = "constant"

    lr_warmup_steps: int = 0
    """The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.
    See lr_scheduler.
    """

    min_snr_gamma: float | None = 5.0
    """Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy
    improves the speed of training convergence by adjusting the weight of each sample.

    `min_snr_gamma` acts like an upper bound on the weight of samples with low noise levels.

    If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`.
    """

    cache_text_encoder_outputs: bool = False
    """If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and
    the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the
    text encoders in VRAM), and speeds up training (don't have to run the text encoders for each training example).
    This option can only be enabled if there are no caption augmentations being applied.
    """

    cache_vae_outputs: bool = False
    """If True, the VAE will be applied to all of the images in the dataset before starting training and the results
    will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and
    speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all
    non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False).
    """

    enable_cpu_offload_during_validation: bool = False
    """If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation
    images. This reduces VRAM requirements at the cost of slower generation of validation images.
    """

    gradient_accumulation_steps: int = 1
    """The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face
    Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.
    """

    weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16"
    """All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and
    result in faster training, but are more prone to issues with numerical stability.

    Recommendations:

    - `"float32"`: Use this mode if you have plenty of VRAM available.
    - `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
    - `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.

    See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig.mixed_precision].
    """  # noqa: E501

    mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
    """The mixed precision mode to use.

    If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and
    trainable parameters are kept in float32 precision to avoid issues with numerical stability.

    This value is passed to Hugging Face Accelerate. See
    [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)
    for more details.
    """  # noqa: E501

    xformers: bool = False
    """If `True`, use xformers for more efficient attention blocks.
    """

    gradient_checkpointing: bool = False
    """Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling
    gradient checkpointing slows down training by ~20%.
    """

    max_checkpoints: int | None = None
    """The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this
    limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.
    """

    prediction_type: Literal["epsilon", "v_prediction"] | None = None
    """The prediction_type that will be used for training. Choose between 'epsilon' or 'v_prediction' or leave 'None'.
    If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used.
    """

    max_grad_norm: float | None = None
    """Max gradient norm for clipping. Set to null or 0 for no clipping.
    """

    validation_prompts: list[str] = []
    """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.
    See also 'validate_every_n_epochs'.
    """

    negative_validation_prompts: list[str] | None = None
    """A list of negative prompts that will be applied when generating validation images. If set, this list should have
    the same length as 'validation_prompts'.
    """

    num_validation_images_per_prompt: int = 4
    """The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can
    become quite slow if this number is too large.
    """

    train_batch_size: int = 4
    """The training batch size.
    """

    use_masks: bool = False
    """If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this
    feature to be used.
    """

    data_loader: Annotated[
        Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig], Field(discriminator="type")
    ]
    """The data configuration. The concrete data loader config class is selected by its `type` field.

    See
    [`ImageCaptionSDDataLoaderConfig`][invoke_training.config.data.data_loader_config.ImageCaptionSDDataLoaderConfig]
    and
    [`DreamboothSDDataLoaderConfig`][invoke_training.config.data.data_loader_config.DreamboothSDDataLoaderConfig]
    for details.
    """

    vae_model: str | None = None
    """The name of the Hugging Face Hub VAE model to train against. This will override the VAE bundled with the base
    model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL shipped
    with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.
    """

    @model_validator(mode="after")
    def check_validation_prompts(self):
        """Ensure `negative_validation_prompts`, when set, has the same length as `validation_prompts`.

        Raises:
            ValueError: If the two prompt lists have different lengths.
        """
        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(
            self.validation_prompts
        ):
            raise ValueError(
                f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of "
                f"negative_validation_prompts ({len(self.negative_validation_prompts)})."
            )
        return self
================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/finetune/train.py
================================================
import itertools
import json
import logging
import math
import os
import tempfile
import time
from typing import Literal
import peft
import torch
import torch.utils.data
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from invoke_training._shared.accelerator.accelerator_utils import (
get_dtype_from_str,
initialize_accelerator,
initialize_logging,
)
from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets
from invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer
from invoke_training._shared.stable_diffusion.checkpoint_utils import (
save_sdxl_diffusers_checkpoint,
save_sdxl_diffusers_unet_checkpoint,
)
from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sdxl
from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sdxl
from invoke_training._shared.utils.import_xformers import import_xformers
from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint
from invoke_training.pipelines.stable_diffusion.lora.train import cache_vae_outputs
from invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig
from invoke_training.pipelines.stable_diffusion_xl.lora.train import (
_build_data_loader,
cache_text_encoder_outputs,
train_forward,
)
def _save_sdxl_checkpoint(
    epoch: int,
    step: int,
    save_checkpoint_format: Literal["full_diffusers", "trained_only_diffusers"],
    vae: AutoencoderKL,
    text_encoder_1: CLIPTextModel,
    text_encoder_2: CLIPTextModel,
    tokenizer_1: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    noise_scheduler: DDPMScheduler,
    unet: UNet2DConditionModel,
    save_dtype: torch.dtype,
    logger: logging.Logger,
    checkpoint_tracker: CheckpointTracker,
    callbacks: list[PipelineCallbacks] | None,
):
    """Save an SDXL fine-tuning checkpoint in the requested format, then notify any callbacks.

    `checkpoint_tracker.prune(1)` is called first, so that older checkpoints are removed as needed to respect the
    tracker's limits.

    Formats:
        - "full_diffusers": every model component is saved via `save_sdxl_diffusers_checkpoint`
          (reported as `ModelType.SDXL_FULL_DIFFUSERS`).
        - "trained_only_diffusers": only the UNet is saved via `save_sdxl_diffusers_unet_checkpoint`
          (reported as `ModelType.SDXL_UNET_DIFFUSERS`).

    Raises:
        ValueError: If `save_checkpoint_format` is not one of the formats above.
    """
    pruned_count = checkpoint_tracker.prune(1)
    if pruned_count > 0:
        logger.info(f"Pruned {pruned_count} checkpoint(s).")

    checkpoint_path = checkpoint_tracker.get_path(epoch=epoch, step=step)

    if save_checkpoint_format == "full_diffusers":
        model_type = ModelType.SDXL_FULL_DIFFUSERS
        save_sdxl_diffusers_checkpoint(
            checkpoint_path=checkpoint_path,
            vae=vae,
            text_encoder_1=text_encoder_1,
            text_encoder_2=text_encoder_2,
            tokenizer_1=tokenizer_1,
            tokenizer_2=tokenizer_2,
            noise_scheduler=noise_scheduler,
            unet=unet,
            save_dtype=save_dtype,
        )
    elif save_checkpoint_format == "trained_only_diffusers":
        model_type = ModelType.SDXL_UNET_DIFFUSERS
        save_sdxl_diffusers_unet_checkpoint(checkpoint_path=checkpoint_path, unet=unet, save_dtype=save_dtype)
    else:
        raise ValueError(f"Invalid save_checkpoint_format: '{save_checkpoint_format}'.")

    for callback in callbacks or []:
        callback.on_save_checkpoint(
            TrainingCheckpoint(
                models=[ModelCheckpoint(file_path=checkpoint_path, model_type=model_type)],
                epoch=epoch,
                step=step,
            )
        )
def train(config: SdxlFinetuneConfig, callbacks: list[PipelineCallbacks] | None = None):  # noqa: C901
    """Run SDXL fine-tuning of the UNet as described by `config`.

    Only the UNet is made trainable and passed to the optimizer. The text encoders and VAE are used in the forward
    pass, or replaced by temporary on-disk caches when `cache_text_encoder_outputs` / `cache_vae_outputs` are enabled.

    A new directory `<config.base_output_dir>/<time.time()>/` is created. It holds the resolved configuration
    (`config.json`) and the `checkpoints/` sub-directory, and is also passed as `out_dir` to validation image
    generation.

    Args:
        config: The SDXL fine-tuning configuration.
        callbacks: Optional pipeline callbacks. They are forwarded to checkpoint saving and to validation image
            generation.

    Raises:
        ValueError: If `cache_vae_outputs` is enabled together with `random_flip=True` or `center_crop=False`.
        AssertionError: Unless exactly one of (max_train_steps, max_train_epochs), one of (save_every_n_steps,
            save_every_n_epochs) and one of (validate_every_n_steps, validate_every_n_epochs) is set. NOTE(review):
            these are `assert` statements and are stripped under `python -O`.
    """
    # Give a clear error message if an unsupported base model was chosen.
    # TODO(ryan): Update this check to work with single-file SD checkpoints.
    # check_base_model_version(
    #     {BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE},
    #     config.model,
    #     local_files_only=False,
    # )

    # Create a timestamped directory for all outputs.
    out_dir = os.path.join(config.base_output_dir, f"{time.time()}")
    ckpt_dir = os.path.join(out_dir, "checkpoints")
    os.makedirs(ckpt_dir)

    accelerator = initialize_accelerator(
        out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to
    )
    logger = initialize_logging(os.path.basename(__file__), accelerator)

    # Set the accelerate seed.
    if config.seed is not None:
        set_seed(config.seed)

    # Log the accelerator configuration from every process to help with debugging.
    logger.info(accelerator.state, main_process_only=False)

    logger.info("Starting Training.")
    logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}")
    logger.info(f"Output dir: '{out_dir}'")

    # Write the configuration to disk.
    with open(os.path.join(out_dir, "config.json"), "w") as f:
        json.dump(config.dict(), f, indent=2, default=str)

    weight_dtype = get_dtype_from_str(config.weight_dtype)

    logger.info("Loading models.")
    tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl(
        logger=logger,
        model_name_or_path=config.model,
        hf_variant=config.hf_variant,
        vae_model=config.vae_model,
        base_embeddings=None,
        dtype=weight_dtype,
    )

    if config.xformers:
        import_xformers()

        # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers
        # will fail because Q, K, V have different dtypes.
        unet.enable_xformers_memory_efficient_attention()
        vae.enable_xformers_memory_efficient_attention()

    # Prepare text encoder output cache.
    text_encoder_output_cache_dir_name = None
    if config.cache_text_encoder_outputs:
        # TODO(ryand): Think about how to better check if it is safe to cache the text encoder outputs. Currently, there
        # are a number of configurations that would cause variation in the text encoder outputs and should not be used
        # with caching.

        # We use a temporary directory for the cache. The directory will automatically be cleaned up when
        # tmp_text_encoder_output_cache_dir is destroyed.
        tmp_text_encoder_output_cache_dir = tempfile.TemporaryDirectory()
        text_encoder_output_cache_dir_name = tmp_text_encoder_output_cache_dir.name
        # NOTE(review): every process creates its own TemporaryDirectory here, but only the local main process fills
        # it below. On multi-process runs the other processes appear to point at a different, empty directory —
        # confirm before relying on caching in distributed training.
        if accelerator.is_local_main_process:
            # Only the main process should populate the cache.
            logger.info(f"Generating text encoder output cache ('{text_encoder_output_cache_dir_name}').")
            text_encoder_1.to(accelerator.device, dtype=weight_dtype)
            text_encoder_2.to(accelerator.device, dtype=weight_dtype)
            # TODO(ryan): Move cache_text_encoder_outputs to a shared location so that it is not imported from another
            # pipeline.
            cache_text_encoder_outputs(
                text_encoder_output_cache_dir_name, config, tokenizer_1, tokenizer_2, text_encoder_1, text_encoder_2
            )
        # Move the text_encoders back to the CPU, because they are not needed for training.
        text_encoder_1.to("cpu")
        text_encoder_2.to("cpu")
        accelerator.wait_for_everyone()
    else:
        # No cache: the text encoders run on every training step, so keep them on the training device.
        text_encoder_1.to(accelerator.device, dtype=weight_dtype)
        text_encoder_2.to(accelerator.device, dtype=weight_dtype)

    # Prepare VAE output cache.
    vae_output_cache_dir_name = None
    if config.cache_vae_outputs:
        # Cached latents are computed once, so any random image augmentation would be frozen into the cache.
        if config.data_loader.random_flip:
            raise ValueError("'cache_vae_outputs' cannot be True if 'random_flip' is True.")
        if not config.data_loader.center_crop:
            raise ValueError("'cache_vae_outputs' cannot be True if 'center_crop' is False.")

        # We use a temporary directory for the cache. The directory will automatically be cleaned up when
        # tmp_vae_output_cache_dir is destroyed.
        tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()
        vae_output_cache_dir_name = tmp_vae_output_cache_dir.name
        if accelerator.is_local_main_process:
            # Only the main process should populate the cache.
            logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').")
            vae.to(accelerator.device, dtype=weight_dtype)
            # TODO(ryan): Move cache_vae_outputs to a shared location so that it is not imported from another
            # pipeline.
            data_loader = _build_data_loader(
                data_loader_config=config.data_loader,
                batch_size=config.train_batch_size,
                use_masks=config.use_masks,
                shuffle=False,
                sequential_batching=True,
            )
            cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)
        # Move the VAE back to the CPU, because it is not needed for training.
        vae.to("cpu")
        accelerator.wait_for_everyone()
    else:
        vae.to(accelerator.device, dtype=weight_dtype)

    unet.to(accelerator.device, dtype=weight_dtype)

    # Make UNet trainable.
    unet.requires_grad_(True)
    unet.train()

    all_trainable_models = [unet]

    # If mixed_precision is enabled, cast all trainable params to float32.
    if config.mixed_precision != "no":
        for trainable_model in all_trainable_models:
            for param in trainable_model.parameters():
                if param.requires_grad:
                    param.data = param.to(torch.float32)

    if config.gradient_checkpointing:
        # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.
        unet.enable_gradient_checkpointing()
        # unet must be in train() mode for gradient checkpointing to take effect.
        # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does
        # not change its forward behavior.
        unet.train()

    # Only the UNet parameters are optimized.
    optimizer = initialize_optimizer(config.optimizer, unet.parameters())

    # The training data loader reads from the caches (if any) created above.
    data_loader = _build_data_loader(
        data_loader_config=config.data_loader,
        batch_size=config.train_batch_size,
        use_masks=config.use_masks,
        text_encoder_output_cache_dir=text_encoder_output_cache_dir_name,
        vae_output_cache_dir=vae_output_cache_dir_name,
    )

    log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler)

    # Exactly one option from each pair must be set.
    assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1
    assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1
    assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1

    # A "step" represents a single weight update operation (i.e. takes into account gradient accumulation steps).
    # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when
    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.
    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)
    num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch
    num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch)

    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps
    # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears
    # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process
    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),
    # so the scaling here simply reverses that behaviour.
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(
        config.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=num_train_steps * accelerator.num_processes,
    )

    prepared_result: tuple[
        UNet2DConditionModel,
        peft.PeftModel | CLIPTextModel,
        peft.PeftModel | CLIPTextModel,
        torch.optim.Optimizer,
        torch.utils.data.DataLoader,
        torch.optim.lr_scheduler.LRScheduler,
    ] = accelerator.prepare(
        unet,
        text_encoder_1,
        text_encoder_2,
        optimizer,
        data_loader,
        lr_scheduler,
        # Disable automatic device placement for text_encoder if the text encoder outputs were cached.
        device_placement=[
            True,
            not config.cache_text_encoder_outputs,
            not config.cache_text_encoder_outputs,
            True,
            True,
            True,
        ],
    )
    unet, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result

    if accelerator.is_main_process:
        accelerator.init_trackers("finetune")
        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.
        accelerator.log({"configuration": f"```json\n{json.dumps(config.dict(), indent=2, default=str)}\n```\n"})

    checkpoint_tracker = CheckpointTracker(
        base_dir=ckpt_dir, prefix="checkpoint", max_checkpoints=config.max_checkpoints
    )

    # Train!
    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num batches = {len(data_loader)}")
    logger.info(f"  Instantaneous batch size per device = {config.train_batch_size}")
    logger.info(f"  Gradient accumulation steps = {config.gradient_accumulation_steps}")
    logger.info(f"  Parallel processes = {accelerator.num_processes}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Total optimization steps = {num_train_steps}")
    logger.info(f"  Total epochs = {num_train_epochs}")

    global_step = 0
    first_epoch = 0
    completed_epochs = 0

    progress_bar = tqdm(
        range(global_step, num_train_steps),
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):
        # Synchronize all processes around the save; only the main process writes to disk.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            _save_sdxl_checkpoint(
                epoch=num_completed_epochs,
                step=num_completed_steps,
                save_checkpoint_format=config.save_checkpoint_format,
                vae=vae,
                text_encoder_1=text_encoder_1,
                text_encoder_2=text_encoder_2,
                tokenizer_1=tokenizer_1,
                tokenizer_2=tokenizer_2,
                noise_scheduler=noise_scheduler,
                unet=unet,
                save_dtype=get_dtype_from_str(config.save_dtype),
                logger=logger,
                checkpoint_tracker=checkpoint_tracker,
                callbacks=callbacks,
            )
        accelerator.wait_for_everyone()

    def validate(num_completed_epochs: int, num_completed_steps: int):
        # Synchronize all processes around validation; only the main process generates images.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            generate_validation_images_sdxl(
                epoch=num_completed_epochs,
                step=num_completed_steps,
                out_dir=out_dir,
                accelerator=accelerator,
                vae=vae,
                text_encoder_1=text_encoder_1,
                text_encoder_2=text_encoder_2,
                tokenizer_1=tokenizer_1,
                tokenizer_2=tokenizer_2,
                noise_scheduler=noise_scheduler,
                unet=unet,
                config=config,
                logger=logger,
                callbacks=callbacks,
            )
        accelerator.wait_for_everyone()

    for epoch in range(first_epoch, num_train_epochs):
        train_loss = 0.0
        for data_batch_idx, data_batch in enumerate(data_loader):
            with accelerator.accumulate(unet, text_encoder_1, text_encoder_2):
                loss = train_forward(
                    accelerator=accelerator,
                    data_batch=data_batch,
                    vae=vae,
                    noise_scheduler=noise_scheduler,
                    tokenizer_1=tokenizer_1,
                    tokenizer_2=tokenizer_2,
                    text_encoder_1=text_encoder_1,
                    text_encoder_2=text_encoder_2,
                    unet=unet,
                    weight_dtype=weight_dtype,
                    resolution=config.data_loader.resolution,
                    use_masks=config.use_masks,
                    prediction_type=config.prediction_type,
                    min_snr_gamma=config.min_snr_gamma,
                )

                # Gather the losses across all processes for logging (if we use distributed training).
                # TODO(ryand): Test that this works properly with distributed training.
                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()
                train_loss += avg_loss.item() / config.gradient_accumulation_steps

                # Backpropagate.
                accelerator.backward(loss)
                # NOTE(review): clipping is skipped only when max_grad_norm is None; a value of 0 is passed through to
                # clip_grad_norm_ — confirm this matches the config documentation.
                if accelerator.sync_gradients and config.max_grad_norm is not None:
                    params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models])
                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Checks if the accelerator has performed an optimization step behind the scenes.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                # An epoch only counts as completed once its last batch has been consumed.
                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1
                log = {"train_loss": train_loss}

                lrs = lr_scheduler.get_last_lr()
                # When training the UNet, it will always be the first parameter group.
                log["lr/unet"] = float(lrs[0])
                if config.optimizer.optimizer_type == "Prodigy":
                    log["lr/d*lr/unet"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]

                accelerator.log(log, step=global_step)
                train_loss = 0.0

                # global_step represents the *number of completed steps* at this point.
                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:
                    save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

                if (
                    config.validate_every_n_steps is not None
                    and global_step % config.validate_every_n_steps == 0
                    and len(config.validation_prompts) > 0
                ):
                    validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            if global_step >= num_train_steps:
                break

        # NOTE(review): if the inner loop broke mid-epoch, completed_epochs still holds the previous epoch count, so
        # the epoch-based save/validate below may fire again with an epoch number that was already saved — confirm
        # this is intended.
        # Save a checkpoint every n epochs.
        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:
            save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

        # Generate validation images every n epochs.
        if (
            config.validate_every_n_epochs is not None
            and completed_epochs % config.validate_every_n_epochs == 0
            and len(config.validation_prompts) > 0
        ):
            validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

    accelerator.end_training()
================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora/__init__.py
================================================
================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora/config.py
================================================
from typing import Annotated, Literal, Union
from pydantic import Field, model_validator
from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (
TEXT_ENCODER_TARGET_MODULES,
UNET_TARGET_MODULES,
)
from invoke_training.config.base_pipeline_config import BasePipelineConfig
from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig
from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig
class SdxlLoraConfig(BasePipelineConfig):
    type: Literal["SDXL_LORA"] = "SDXL_LORA"

    model: str = "stabilityai/stable-diffusion-xl-base-1.0"
    """Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint
    file. (E.g. 'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/JuggernautXL.safetensors', etc. )
    """

    hf_variant: str | None = "fp16"
    """The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.
    """

    # Note: Pydantic handles mutable default values well:
    # https://docs.pydantic.dev/latest/concepts/models/#fields-with-non-hashable-default-values
    base_embeddings: dict[str, str] = {}
    """A mapping of embedding tokens to trained embedding file paths. These embeddings will be applied to the base model
    before training.

    Example:
    ```
    base_embeddings = {
        "bruce_the_gnome": "/path/to/bruce_the_gnome.safetensors",
    }
    ```

    Consider also adding the embedding tokens to the `data_loader.caption_prefix` if they are not already present in the
    dataset captions.

    Note that the embeddings themselves are not fine-tuned further, but they will impact the LoRA model training if they
    are referenced in the dataset captions. The list of embeddings provided here should be the same list used at
    generation time with the resultant LoRA model.
    """

    lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"
    """The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`."""

    train_unet: bool = True
    """Whether to add LoRA layers to the UNet model and train it.
    """

    train_text_encoder: bool = True
    """Whether to add LoRA layers to the text encoder and train it.
    """

    optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()
    """The optimizer configuration to use for training (Adam or Prodigy).
    """

    text_encoder_learning_rate: float | None = None
    """The learning rate to use for the text encoder model. If set, this overrides the optimizer's default learning
    rate. Set to null or 0 to use the optimizer's default learning rate.
    """

    unet_learning_rate: float | None = None
    """The learning rate to use for the UNet model. If set, this overrides the optimizer's default learning rate.
    Set to null or 0 to use the optimizer's default learning rate.
    """

    lr_scheduler: Literal[
        "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"
    ] = "constant"
    """The learning rate scheduler type. See also `lr_warmup_steps`.
    """

    lr_warmup_steps: int = 0
    """The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.
    See lr_scheduler.
    """

    min_snr_gamma: float | None = 5.0
    """Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy
    improves the speed of training convergence by adjusting the weight of each sample.

    `min_snr_gamma` acts like an upper bound on the weight of samples with low noise levels.

    If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`.
    """

    lora_rank_dim: int = 4
    """The rank dimension to use for the LoRA layers. Increasing the rank dimension increases the model's expressivity,
    but also increases the size of the generated LoRA model.
    """

    # The default list of target modules is based on
    # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/stable_diffusion/train_dreambooth.py#L49C1-L65C87
    unet_lora_target_modules: list[str] = UNET_TARGET_MODULES
    """The list of target modules to apply LoRA layers to in the UNet model. The default list will produce a highly
    expressive LoRA model.

    For a smaller and less expressive LoRA model, the following list is recommended:
    ```python
    unet_lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
    ```

    The list of target modules is passed to Hugging Face's PEFT library. See
    [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for
    details.
    """

    text_encoder_lora_target_modules: list[str] = TEXT_ENCODER_TARGET_MODULES
    """The list of target modules to apply LoRA layers to in the text encoder models. The default list will produce a
    highly expressive LoRA model.

    For a smaller and less expressive LoRA model, the following list is recommended:
    ```python
    text_encoder_lora_target_modules = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"]
    ```

    The list of target modules is passed to Hugging Face's PEFT library. See
    [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for
    details.
    """

    cache_text_encoder_outputs: bool = False
    """If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and
    the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the
    text encoders in VRAM), and speeds up training (don't have to run the text encoders for each training example).
    This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being
    applied.
    """

    cache_vae_outputs: bool = False
    """If True, the VAE will be applied to all of the images in the dataset before starting training and the results
    will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and
    speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all
    non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False).
    """

    enable_cpu_offload_during_validation: bool = False
    """If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation
    images. This reduces VRAM requirements at the cost of slower generation of validation images.
    """

    gradient_accumulation_steps: int = 1
    """The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face
    Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.
    """

    weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16"
    """All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and
    result in faster training, but are more prone to issues with numerical stability.

    Recommendations:

    - `"float32"`: Use this mode if you have plenty of VRAM available.
    - `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
    - `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.

    See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig.mixed_precision].
    """  # noqa: E501

    mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
    """The mixed precision mode to use.

    If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and
    trainable parameters are kept in float32 precision to avoid issues with numerical stability.

    This value is passed to Hugging Face Accelerate. See
    [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)
    for more details.
    """  # noqa: E501

    xformers: bool = False
    """If true, use xformers for more efficient attention blocks.
    """

    gradient_checkpointing: bool = False
    """Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling
    gradient checkpointing slows down training by ~20%.
    """

    max_checkpoints: int | None = None
    """The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this
    limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.
    """

    prediction_type: Literal["epsilon", "v_prediction"] | None = None
    """The prediction_type that will be used for training. Choose between 'epsilon' or 'v_prediction' or leave 'None'.
    If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used.
    """

    max_grad_norm: float | None = None
    """Max gradient norm for clipping. Set to null or 0 for no clipping.
    """

    validation_prompts: list[str] = []
    """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.
    See also 'validate_every_n_epochs'.
    """

    negative_validation_prompts: list[str] | None = None
    """A list of negative prompts that will be applied when generating validation images. If set, this list should have
    the same length as 'validation_prompts'.
    """

    num_validation_images_per_prompt: int = 4
    """The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can
    become quite slow if this number is too large.
    """

    train_batch_size: int = 4
    """The training batch size.
    """

    use_masks: bool = False
    """If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this
    feature to be used.
    """

    data_loader: Annotated[
        Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig], Field(discriminator="type")
    ]
    """The data loader configuration. Its `type` field selects between the image-caption and DreamBooth data loaders.
    """

    vae_model: str | None = None
    """The name of the Hugging Face Hub VAE model to train against. This will override the VAE bundled with the base
    model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL shipped
    with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.
    """

    @model_validator(mode="after")
    def check_validation_prompts(self):
        """Ensure `negative_validation_prompts`, when provided, pairs one-to-one with `validation_prompts`.

        Raises:
            ValueError: If both lists are set and their lengths differ.
        """
        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(
            self.validation_prompts
        ):
            raise ValueError(
                f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of "
                f"negative_validation_prompts ({len(self.negative_validation_prompts)})."
            )
        return self
================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora/train.py
================================================
import itertools
import json
import logging
import math
import os
import tempfile
import time
from pathlib import Path
from typing import Literal, Optional, Union
import peft
import torch
import torch.utils.data
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import CLIPPreTrainedModel, CLIPTextModel, PreTrainedTokenizer
from invoke_training._shared.accelerator.accelerator_utils import (
get_dtype_from_str,
initialize_accelerator,
initialize_logging,
)
from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker
from invoke_training._shared.data.data_loaders.dreambooth_sd_dataloader import build_dreambooth_sd_dataloader
from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import build_image_caption_sd_dataloader
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training._shared.data.utils.resolution import Resolution
from invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer
from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (
save_sdxl_kohya_checkpoint,
save_sdxl_peft_checkpoint,
)
from invoke_training._shared.stable_diffusion.min_snr_weighting import compute_snr
from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sdxl
from invoke_training._shared.stable_diffusion.tokenize_captions import tokenize_captions
from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sdxl
from invoke_training._shared.utils.import_xformers import import_xformers
from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig
from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint
from invoke_training.pipelines.stable_diffusion.lora.train import cache_vae_outputs
from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig
def _save_sdxl_lora_checkpoint(
    epoch: int,
    step: int,
    unet: peft.PeftModel | None,
    text_encoder_1: peft.PeftModel | None,
    text_encoder_2: peft.PeftModel | None,
    logger: logging.Logger,
    checkpoint_tracker: CheckpointTracker,
    lora_checkpoint_format: Literal["invoke_peft", "kohya"],
    callbacks: list[PipelineCallbacks] | None,
):
    """Save an SDXL LoRA checkpoint and notify any callbacks.

    Args:
        epoch: Number of completed epochs; used to build the checkpoint path and reported to callbacks.
        step: Number of completed steps; used to build the checkpoint path and reported to callbacks.
        unet, text_encoder_1, text_encoder_2: PEFT-wrapped models whose LoRA weights are saved. Any of them may be None.
        logger: Logger used to report pruning.
        checkpoint_tracker: Provides the new checkpoint path and prunes older checkpoints.
        lora_checkpoint_format: "invoke_peft" (PEFT layout) or "kohya".
        callbacks: Optional callbacks. Each receives `on_save_checkpoint` with a `TrainingCheckpoint`.

    Raises:
        ValueError: If `lora_checkpoint_format` is not supported. This is checked before any checkpoint is pruned.
    """
    # Validate before pruning so that an invalid format cannot delete an existing checkpoint.
    if lora_checkpoint_format not in ("invoke_peft", "kohya"):
        raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")

    # Prune checkpoints and get new checkpoint path.
    num_pruned = checkpoint_tracker.prune(1)
    if num_pruned > 0:
        logger.info(f"Pruned {num_pruned} checkpoint(s).")
    save_path = checkpoint_tracker.get_path(epoch=epoch, step=step)

    # These are SDXL LoRA weights, so callbacks must be told the SDXL model type (previously SD1_* was reported).
    if lora_checkpoint_format == "invoke_peft":
        model_type = ModelType.SDXL_LORA_PEFT
        save_sdxl_peft_checkpoint(
            Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2
        )
    else:  # "kohya" (validated above)
        model_type = ModelType.SDXL_LORA_KOHYA
        save_sdxl_kohya_checkpoint(
            Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2
        )

    if callbacks is not None:
        for cb in callbacks:
            cb.on_save_checkpoint(
                TrainingCheckpoint(
                    models=[ModelCheckpoint(file_path=save_path, model_type=model_type)], epoch=epoch, step=step
                )
            )
def _build_data_loader(
    data_loader_config: Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig],
    batch_size: int,
    use_masks: bool = False,
    text_encoder_output_cache_dir: Optional[str] = None,
    vae_output_cache_dir: Optional[str] = None,
    shuffle: bool = True,
    sequential_batching: bool = False,
) -> DataLoader:
    """Create the SDXL training DataLoader selected by `data_loader_config.type`.

    Both supported loaders map the cached text encoder fields "prompt_embeds" and "pooled_prompt_embeds" onto output
    fields of the same name.

    Args:
        data_loader_config: Image-caption or DreamBooth data loader config.
        batch_size: Batch size for the loader.
        use_masks: Whether to load masks. Only supported by the image-caption loader.
        text_encoder_output_cache_dir: Optional directory of cached text encoder outputs.
        vae_output_cache_dir: Optional directory of cached VAE outputs.
        shuffle: Whether to shuffle the dataset.
        sequential_batching: Forwarded only to the DreamBooth loader. NOTE(review): it is ignored for the
            image-caption loader — confirm that is intended.

    Raises:
        ValueError: If masks are requested for a DreamBooth loader, or the loader type is not supported.
    """
    loader_type = data_loader_config.type

    # Identity mapping from text encoder cache fields to data batch fields, shared by both loader types.
    cache_field_map = {
        "prompt_embeds": "prompt_embeds",
        "pooled_prompt_embeds": "pooled_prompt_embeds",
    }

    if loader_type == "IMAGE_CAPTION_SD_DATA_LOADER":
        return build_image_caption_sd_dataloader(
            config=data_loader_config,
            batch_size=batch_size,
            use_masks=use_masks,
            text_encoder_output_cache_dir=text_encoder_output_cache_dir,
            text_encoder_cache_field_to_output_field=cache_field_map,
            vae_output_cache_dir=vae_output_cache_dir,
            shuffle=shuffle,
        )

    if loader_type != "DREAMBOOTH_SD_DATA_LOADER":
        raise ValueError(f"Unsupported data loader config type: '{loader_type}'.")

    if use_masks:
        raise ValueError("Masks are not yet supported for DreamBooth data loaders.")

    return build_dreambooth_sd_dataloader(
        config=data_loader_config,
        batch_size=batch_size,
        text_encoder_output_cache_dir=text_encoder_output_cache_dir,
        text_encoder_cache_field_to_output_field=cache_field_map,
        vae_output_cache_dir=vae_output_cache_dir,
        shuffle=shuffle,
        sequential_batching=sequential_batching,
    )
# encode_prompt was adapted from:
# https://github.com/huggingface/diffusers/blob/7b07f9812a58bfa96c06ed8ffe9e6b584286e2fd/examples/text_to_image/train_text_to_image_lora_sdxl.py#L470-L496
def _encode_prompt(text_encoders: list[CLIPPreTrainedModel], prompt_token_ids_list: list[torch.Tensor]):
    """Encode tokenized prompts with each text encoder and combine the results.

    `prompt_token_ids_list[i]` is fed to `text_encoders[i]`.

    Returns:
        A `(prompt_embeds, pooled_prompt_embeds)` tuple:
        - prompt_embeds: the second-to-last hidden state of every encoder, concatenated along the last dimension.
        - pooled_prompt_embeds: element 0 of the *last* encoder's output, reshaped to `(batch, -1)`. Element 0 of the
          earlier encoders' outputs is discarded.
    """
    penultimate_states = []
    for encoder_idx, encoder in enumerate(text_encoders):
        encoder_output = encoder(
            prompt_token_ids_list[encoder_idx].to(encoder.device),
            output_hidden_states=True,
        )

        # Overwritten on every iteration, so only the final encoder's element 0 survives the loop.
        pooled_prompt_embeds = encoder_output[0]

        hidden = encoder_output.hidden_states[-2]
        batch_dim, seq_dim, _ = hidden.shape
        penultimate_states.append(hidden.view(batch_dim, seq_dim, -1))

    prompt_embeds = torch.concat(penultimate_states, dim=-1)
    return prompt_embeds, pooled_prompt_embeds.view(batch_dim, -1)
# TODO(ryand): Cache VAE outputs and text encoder outputs at the same time in a single pass over the dataset.
def cache_text_encoder_outputs(
    cache_dir: str,
    config: SdxlLoraConfig,
    tokenizer_1: PreTrainedTokenizer,
    tokenizer_2: PreTrainedTokenizer,
    text_encoder_1: CLIPPreTrainedModel,
    text_encoder_2: CLIPPreTrainedModel,
):
    """Run the text encoders on all captions in the dataset and cache the results to disk.

    The dataset is iterated unshuffled and in sequential batches. One cache entry is written per example, keyed by the
    example's "id", holding its "prompt_embeds" and "pooled_prompt_embeds".

    Args:
        cache_dir (str): The directory where the results will be cached.
        config (SdxlLoraConfig): Training config. Only `config.data_loader` and `config.train_batch_size` are read, so
            other configs exposing these fields (e.g. the SDXL finetune config) also work.
        tokenizer_1 (PreTrainedTokenizer): Tokenizer whose token ids are fed to `text_encoder_1`.
        tokenizer_2 (PreTrainedTokenizer): Tokenizer whose token ids are fed to `text_encoder_2`.
        text_encoder_1 (CLIPPreTrainedModel): First text encoder.
        text_encoder_2 (CLIPPreTrainedModel): Second text encoder. Its pooled output is the one that is cached.
    """
    data_loader = _build_data_loader(
        data_loader_config=config.data_loader,
        batch_size=config.train_batch_size,
        shuffle=False,
        sequential_batching=True,
    )

    cache = TensorDiskCache(cache_dir)

    # This is an inference-only pass. Disable autograd so that no graph is recorded (saving memory) and the cached
    # embeddings are plain tensors, regardless of the encoders' requires_grad state.
    with torch.no_grad():
        for data_batch in tqdm(data_loader):
            caption_token_ids_1 = tokenize_captions(tokenizer_1, data_batch["caption"])
            caption_token_ids_2 = tokenize_captions(tokenizer_2, data_batch["caption"])
            prompt_embeds, pooled_prompt_embeds = _encode_prompt(
                [text_encoder_1, text_encoder_2], [caption_token_ids_1, caption_token_ids_2]
            )

            # Split batch before caching.
            for i in range(len(data_batch["id"])):
                embeds = {
                    "prompt_embeds": prompt_embeds[i],
                    "pooled_prompt_embeds": pooled_prompt_embeds[i],
                }
                cache.save(data_batch["id"][i], embeds)
def train_forward(  # noqa: C901
    accelerator: Accelerator,
    data_batch: dict,
    vae: AutoencoderKL,
    noise_scheduler: DDPMScheduler,
    tokenizer_1: PreTrainedTokenizer,
    tokenizer_2: PreTrainedTokenizer,
    text_encoder_1: CLIPPreTrainedModel,
    text_encoder_2: CLIPPreTrainedModel,
    unet: UNet2DConditionModel,
    weight_dtype: torch.dtype,
    resolution: int | tuple[int, int],
    use_masks: bool = False,
    prediction_type=None,
    min_snr_gamma: float | None = None,
):
    """Run the forward training pass for a single data_batch.

    Args:
        accelerator: Its device is used for the SDXL "time_ids" conditioning tensor.
        data_batch: A batch from the data loader. Keys read here:
            - "vae_output" (cached latents) or "image";
            - "original_size_hw" and "crop_top_left_yx";
            - "prompt_embeds" and "pooled_prompt_embeds" (cached), or "caption";
            - optionally "mask" (when `use_masks`) and "loss_weight".
        vae: Used to encode "image" when "vae_output" is not cached. Its `config.scaling_factor` is always applied.
        noise_scheduler: Used to add noise, and to compute the loss target and SNR.
        tokenizer_1, tokenizer_2, text_encoder_1, text_encoder_2: Used to encode "caption" when the text encoder
            outputs are not cached.
        unet: The model whose prediction is trained.
        weight_dtype: dtype that the VAE input, text embeddings, and time ids are cast to.
        resolution: Training resolution, used as the target size in the SDXL time ids.
        use_masks: If True, the per-element loss is multiplied by data_batch["mask"], resized to the latent size.
        prediction_type: If not None, it is registered on the scheduler config before computing the target.
        min_snr_gamma: If not None, apply Min-SNR loss weighting (https://arxiv.org/abs/2303.09556) with this gamma.

    Returns:
        torch.Tensor: Loss (mean over the batch).

    Raises:
        ValueError: If the scheduler's prediction type is neither "epsilon" nor "v_prediction".
    """
    # Convert images to latent space.
    # The VAE output may have been cached and included in the data_batch. If not, we calculate it here.
    latents = data_batch.get("vae_output", None)
    if latents is None:
        latents = vae.encode(data_batch["image"].to(dtype=weight_dtype)).latent_dist.sample()

    latents = latents * vae.config.scaling_factor

    # Sample noise that we'll add to the latents.
    noise = torch.randn_like(latents)

    batch_size = latents.shape[0]
    # Sample a random timestep for each image.
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (batch_size,),
        device=latents.device,
    )
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep (this is the forward diffusion
    # process).
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # compute_time_ids was copied from:
    # https://github.com/huggingface/diffusers/blob/7b07f9812a58bfa96c06ed8ffe9e6b584286e2fd/examples/text_to_image/train_text_to_image_lora_sdxl.py#L1033-L1039
    # "time_ids" may seem like a weird naming choice. The name comes from the diffusers SDXL implementation. Presumably,
    # it is a result of the fact that the original size and crop values get concatenated with the time embeddings.
    # The target size is identical for every example, so parse it once rather than once per example.
    target_size = Resolution.parse(resolution).to_tuple()

    def compute_time_ids(original_size, crops_coords_top_left):
        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids])
        add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
        return add_time_ids

    add_time_ids = torch.cat(
        [compute_time_ids(s, c) for s, c in zip(data_batch["original_size_hw"], data_batch["crop_top_left_yx"])]
    )
    unet_conditions = {"time_ids": add_time_ids}

    # Get the text embedding for conditioning.
    # The text encoder output may have been cached and included in the data_batch. If not, we calculate it here.
    if "prompt_embeds" in data_batch:
        prompt_embeds = data_batch["prompt_embeds"]
        pooled_prompt_embeds = data_batch["pooled_prompt_embeds"]
    else:
        caption_token_ids_1 = tokenize_captions(tokenizer_1, data_batch["caption"])
        caption_token_ids_2 = tokenize_captions(tokenizer_2, data_batch["caption"])
        prompt_embeds, pooled_prompt_embeds = _encode_prompt(
            [text_encoder_1, text_encoder_2], [caption_token_ids_1, caption_token_ids_2]
        )
    prompt_embeds = prompt_embeds.to(dtype=weight_dtype)
    pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=weight_dtype)

    unet_conditions["text_embeds"] = pooled_prompt_embeds

    # Get the target for loss depending on the prediction type.
    if prediction_type is not None:
        # Set the prediction_type of scheduler if it's defined in config.
        noise_scheduler.register_to_config(prediction_type=prediction_type)
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    # Predict the noise residual.
    model_pred = unet(noisy_latents, timesteps, prompt_embeds, added_cond_kwargs=unet_conditions).sample

    min_snr_weights = None
    if min_snr_gamma is not None:
        # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
        # The paper defines the weights for an x_0 objective: w_t = min(SNR(t), gamma). We predict epsilon or v
        # instead of x_0, so the weights are converted as discussed in Section 4.2 of the same paper:
        #   epsilon:      ||eps - eps_hat||^2 = SNR(t) * ||x_0 - x_0_hat||^2
        #                 => w_t = min(SNR(t), gamma) / SNR(t)
        #   v_prediction: ||v - v_hat||^2 = (SNR(t) + 1) * ||x_0 - x_0_hat||^2
        #                 => w_t = min(SNR(t), gamma) / (SNR(t) + 1)
        snr = compute_snr(noise_scheduler, timesteps)
        clamped_snr = torch.clamp(snr, max=min_snr_gamma)

        if noise_scheduler.config.prediction_type == "epsilon":
            min_snr_weights = clamped_snr / snr
        elif noise_scheduler.config.prediction_type == "v_prediction":
            min_snr_weights = clamped_snr / (snr + 1)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none")

    if use_masks:
        # TODO(ryand): As a future performance optimization, we may want to do this resizing in the dataloader.
        mask = data_batch["mask"].to(dtype=loss.dtype, device=loss.device)
        _, _, latent_h, latent_w = loss.shape
        mask = torch.nn.functional.interpolate(mask, size=(latent_h, latent_w), mode="nearest")
        loss = loss * mask

    # Mean-reduce the loss along all dimensions except for the batch dimension.
    loss = loss.mean(dim=list(range(1, len(loss.shape))))

    # Apply min_snr_weights.
    if min_snr_weights is not None:
        loss = loss * min_snr_weights

    # Apply per-example loss weights.
    if "loss_weight" in data_batch:
        loss = loss * data_batch["loss_weight"]

    return loss.mean()
def train(config: SdxlLoraConfig, callbacks: list[PipelineCallbacks] | None = None):  # noqa: C901
    """Train SDXL LoRA layers (on the UNet and/or the text encoders) according to `config`.

    All outputs are written to a new timestamped directory under `config.base_output_dir`. This includes config.json,
    the checkpoints/ sub-directory, and whatever the validation and tracking helpers write.

    Args:
        config: The training configuration.
        callbacks: Optional pipeline callbacks, forwarded to the checkpoint-saving and validation helpers.

    Raises:
        ValueError: If mutually-exclusive config options are set incorrectly. Also raised if text encoder output
            caching is combined with text encoder training, or if VAE output caching is combined with
            non-deterministic image augmentations.
    """
    # Validate mutually-exclusive options before any expensive work (model loading, caching, ...) is done. Explicit
    # exceptions are used rather than `assert`, because asserts are stripped when Python runs with `-O`.
    for option_a, option_b in [
        ("max_train_steps", "max_train_epochs"),
        ("save_every_n_steps", "save_every_n_epochs"),
        ("validate_every_n_steps", "validate_every_n_epochs"),
    ]:
        num_set = sum(getattr(config, name) is not None for name in (option_a, option_b))
        if num_set != 1:
            raise ValueError(f"Exactly one of '{option_a}' or '{option_b}' must be set (got {num_set}).")

    # Give a clear error message if an unsupported base model was chosen.
    # TODO(ryan): Update this check to work with single-file SD checkpoints.
    # check_base_model_version(
    #     {BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE},
    #     config.model,
    #     local_files_only=False,
    # )

    # Create a timestamped directory for all outputs.
    out_dir = os.path.join(config.base_output_dir, f"{time.time()}")
    ckpt_dir = os.path.join(out_dir, "checkpoints")
    os.makedirs(ckpt_dir)

    accelerator = initialize_accelerator(
        out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to
    )
    logger = initialize_logging(os.path.basename(__file__), accelerator)

    # Set the accelerate seed.
    if config.seed is not None:
        set_seed(config.seed)

    # Log the accelerator configuration from every process to help with debugging.
    logger.info(accelerator.state, main_process_only=False)

    # `model_dump()` is the Pydantic v2 replacement for the deprecated `dict()`.
    config_dict = config.model_dump()

    logger.info("Starting Training.")
    logger.info(f"Configuration:\n{json.dumps(config_dict, indent=2, default=str)}")
    logger.info(f"Output dir: '{out_dir}'")

    # Write the configuration to disk.
    with open(os.path.join(out_dir, "config.json"), "w") as f:
        json.dump(config_dict, f, indent=2, default=str)

    weight_dtype = get_dtype_from_str(config.weight_dtype)

    logger.info("Loading models.")
    tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl(
        logger=logger,
        model_name_or_path=config.model,
        hf_variant=config.hf_variant,
        vae_model=config.vae_model,
        base_embeddings=config.base_embeddings,
        dtype=weight_dtype,
    )

    if config.xformers:
        import_xformers()

        # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers
        # will fail because Q, K, V have different dtypes.
        unet.enable_xformers_memory_efficient_attention()
        vae.enable_xformers_memory_efficient_attention()

    # Prepare text encoder output cache.
    text_encoder_output_cache_dir_name = None
    if config.cache_text_encoder_outputs:
        # TODO(ryand): Think about how to better check if it is safe to cache the text encoder outputs. Currently, there
        # are a number of configurations that would cause variation in the text encoder outputs and should not be used
        # with caching.
        if config.train_text_encoder:
            raise ValueError("'cache_text_encoder_outputs' and 'train_text_encoder' cannot both be True.")

        # We use a temporary directory for the cache. The directory will automatically be cleaned up when
        # tmp_text_encoder_output_cache_dir is destroyed.
        tmp_text_encoder_output_cache_dir = tempfile.TemporaryDirectory()
        text_encoder_output_cache_dir_name = tmp_text_encoder_output_cache_dir.name
        if accelerator.is_local_main_process:
            # Only the main process should populate the cache.
            logger.info(f"Generating text encoder output cache ('{text_encoder_output_cache_dir_name}').")
            text_encoder_1.to(accelerator.device, dtype=weight_dtype)
            text_encoder_2.to(accelerator.device, dtype=weight_dtype)
            cache_text_encoder_outputs(
                text_encoder_output_cache_dir_name, config, tokenizer_1, tokenizer_2, text_encoder_1, text_encoder_2
            )
        # Move the text_encoders back to the CPU, because they are not needed for training.
        text_encoder_1.to("cpu")
        text_encoder_2.to("cpu")
        # Every process must reach this barrier, so it sits outside the main-process-only branch.
        accelerator.wait_for_everyone()
    else:
        text_encoder_1.to(accelerator.device, dtype=weight_dtype)
        text_encoder_2.to(accelerator.device, dtype=weight_dtype)

    # Prepare VAE output cache.
    vae_output_cache_dir_name = None
    if config.cache_vae_outputs:
        if config.data_loader.random_flip:
            raise ValueError("'cache_vae_outputs' cannot be True if 'random_flip' is True.")
        if not config.data_loader.center_crop:
            raise ValueError("'cache_vae_outputs' cannot be True if 'center_crop' is False.")

        # We use a temporary directory for the cache. The directory will automatically be cleaned up when
        # tmp_vae_output_cache_dir is destroyed.
        tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()
        vae_output_cache_dir_name = tmp_vae_output_cache_dir.name
        if accelerator.is_local_main_process:
            # Only the main process should populate the cache.
            logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').")
            vae.to(accelerator.device, dtype=weight_dtype)
            data_loader = _build_data_loader(
                data_loader_config=config.data_loader,
                batch_size=config.train_batch_size,
                use_masks=config.use_masks,
                shuffle=False,
                sequential_batching=True,
            )
            cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)
        # Move the VAE back to the CPU, because it is not needed for training.
        vae.to("cpu")
        accelerator.wait_for_everyone()
    else:
        vae.to(accelerator.device, dtype=weight_dtype)

    unet.to(accelerator.device, dtype=weight_dtype)

    # Add LoRA layers to the models being trained.
    trainable_param_groups = []
    all_trainable_models: list[peft.PeftModel] = []

    def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = None) -> peft.PeftModel:
        peft_model = peft.get_peft_model(model, lora_config)
        peft_model.print_trainable_parameters()

        # Populate `trainable_param_groups`, to be passed to the optimizer.
        param_group = {"params": list(filter(lambda p: p.requires_grad, peft_model.parameters()))}
        # Per the config docs, a learning rate of None or 0 means "use the optimizer's default learning rate".
        # A truthiness check handles both. Passing lr=0 through would silently freeze these parameters.
        if lr:
            param_group["lr"] = lr
        trainable_param_groups.append(param_group)

        # Populate all_trainable_models.
        all_trainable_models.append(peft_model)

        peft_model.train()

        return peft_model

    if config.train_unet:
        unet_lora_config = peft.LoraConfig(
            r=config.lora_rank_dim,
            # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred?
            lora_alpha=1.0,
            target_modules=config.unet_lora_target_modules,
        )
        unet = inject_lora_layers(unet, unet_lora_config, lr=config.unet_learning_rate)

    if config.train_text_encoder:
        text_encoder_lora_config = peft.LoraConfig(
            r=config.lora_rank_dim,
            lora_alpha=1.0,
            # init_lora_weights="gaussian",
            target_modules=config.text_encoder_lora_target_modules,
        )
        text_encoder_1 = inject_lora_layers(
            text_encoder_1, text_encoder_lora_config, lr=config.text_encoder_learning_rate
        )
        text_encoder_2 = inject_lora_layers(
            text_encoder_2, text_encoder_lora_config, lr=config.text_encoder_learning_rate
        )

    # If mixed_precision is enabled, cast all trainable params to float32.
    if config.mixed_precision != "no":
        for trainable_model in all_trainable_models:
            for param in trainable_model.parameters():
                if param.requires_grad:
                    param.data = param.to(torch.float32)

    if config.gradient_checkpointing:
        # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.
        unet.enable_gradient_checkpointing()
        # unet must be in train() mode for gradient checkpointing to take effect.
        # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does
        # not change its forward behavior.
        unet.train()
        if config.train_text_encoder:
            for te in [text_encoder_1, text_encoder_2]:
                te.gradient_checkpointing_enable()

                # The text encoders must be in train() mode for gradient checkpointing to take effect. This should
                # already be the case, since we are training the text_encoders, but we do it explicitly to make it
                # clear that this is required.
                # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text
                # encoders in train mode does not change their forward behavior.
                te.train()

                # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder
                # LoRA weights would have 0 gradients, and so would not get trained. Note that the set of
                # trainable_param_groups has already been populated - the embeddings will not be trained.
                te.text_model.embeddings.requires_grad_(True)

    optimizer = initialize_optimizer(config.optimizer, trainable_param_groups)

    data_loader = _build_data_loader(
        data_loader_config=config.data_loader,
        batch_size=config.train_batch_size,
        use_masks=config.use_masks,
        text_encoder_output_cache_dir=text_encoder_output_cache_dir_name,
        vae_output_cache_dir=vae_output_cache_dir_name,
    )

    log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler)

    # A "step" represents a single weight update operation (i.e. takes into account gradient accumulation steps).
    # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when
    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.
    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)
    # Exactly one of max_train_steps / max_train_epochs is set (validated at the top). Use an explicit None check so
    # that a value of 0 is not mistaken for "unset".
    if config.max_train_steps is not None:
        num_train_steps = config.max_train_steps
    else:
        num_train_steps = config.max_train_epochs * num_steps_per_epoch
    num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch)

    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps
    # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears
    # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process
    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),
    # so the scaling here simply reverses that behaviour.
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(
        config.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=num_train_steps * accelerator.num_processes,
    )

    prepared_result: tuple[
        UNet2DConditionModel,
        peft.PeftModel | CLIPTextModel,
        peft.PeftModel | CLIPTextModel,
        torch.optim.Optimizer,
        torch.utils.data.DataLoader,
        torch.optim.lr_scheduler.LRScheduler,
    ] = accelerator.prepare(
        unet,
        text_encoder_1,
        text_encoder_2,
        optimizer,
        data_loader,
        lr_scheduler,
        # Disable automatic device placement for text_encoder if the text encoder outputs were cached.
        device_placement=[
            True,
            not config.cache_text_encoder_outputs,
            not config.cache_text_encoder_outputs,
            True,
            True,
            True,
        ],
    )
    unet, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result

    if accelerator.is_main_process:
        accelerator.init_trackers("lora_training")
        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.
        accelerator.log({"configuration": f"```json\n{json.dumps(config_dict, indent=2, default=str)}\n```\n"})

    checkpoint_tracker = CheckpointTracker(
        base_dir=ckpt_dir,
        prefix="checkpoint",
        max_checkpoints=config.max_checkpoints,
        extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None,
    )

    # Train!
    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num batches = {len(data_loader)}")
    logger.info(f"  Instantaneous batch size per device = {config.train_batch_size}")
    logger.info(f"  Gradient accumulation steps = {config.gradient_accumulation_steps}")
    logger.info(f"  Parallel processes = {accelerator.num_processes}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Total optimization steps = {num_train_steps}")
    logger.info(f"  Total epochs = {num_train_epochs}")

    global_step = 0
    first_epoch = 0
    completed_epochs = 0

    progress_bar = tqdm(
        range(global_step, num_train_steps),
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            _save_sdxl_lora_checkpoint(
                epoch=num_completed_epochs,
                step=num_completed_steps,
                unet=unet if config.train_unet else None,
                text_encoder_1=text_encoder_1 if config.train_text_encoder else None,
                text_encoder_2=text_encoder_2 if config.train_text_encoder else None,
                logger=logger,
                checkpoint_tracker=checkpoint_tracker,
                lora_checkpoint_format=config.lora_checkpoint_format,
                callbacks=callbacks,
            )
        accelerator.wait_for_everyone()

    def validate(num_completed_epochs: int, num_completed_steps: int):
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            generate_validation_images_sdxl(
                epoch=num_completed_epochs,
                step=num_completed_steps,
                out_dir=out_dir,
                accelerator=accelerator,
                vae=vae,
                text_encoder_1=text_encoder_1,
                text_encoder_2=text_encoder_2,
                tokenizer_1=tokenizer_1,
                tokenizer_2=tokenizer_2,
                noise_scheduler=noise_scheduler,
                unet=unet,
                config=config,
                logger=logger,
                callbacks=callbacks,
            )
        accelerator.wait_for_everyone()

    for epoch in range(first_epoch, num_train_epochs):
        train_loss = 0.0
        for data_batch_idx, data_batch in enumerate(data_loader):
            with accelerator.accumulate(unet, text_encoder_1, text_encoder_2):
                loss = train_forward(
                    accelerator=accelerator,
                    data_batch=data_batch,
                    vae=vae,
                    noise_scheduler=noise_scheduler,
                    tokenizer_1=tokenizer_1,
                    tokenizer_2=tokenizer_2,
                    text_encoder_1=text_encoder_1,
                    text_encoder_2=text_encoder_2,
                    unet=unet,
                    weight_dtype=weight_dtype,
                    resolution=config.data_loader.resolution,
                    use_masks=config.use_masks,
                    prediction_type=config.prediction_type,
                    min_snr_gamma=config.min_snr_gamma,
                )

                # Gather the losses across all processes for logging (if we use distributed training).
                # TODO(ryand): Test that this works properly with distributed training.
                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()
                train_loss += avg_loss.item() / config.gradient_accumulation_steps

                # Backpropagate.
                accelerator.backward(loss)
                # A max_grad_norm of None or 0 disables clipping (per the config docs). Clipping to a max norm of 0
                # would zero out every gradient.
                if accelerator.sync_gradients and config.max_grad_norm:
                    params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models])
                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Checks if the accelerator has performed an optimization step behind the scenes.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1
                log = {"train_loss": train_loss}

                lrs = lr_scheduler.get_last_lr()
                if config.train_unet:
                    # When training the UNet, it will always be the first parameter group.
                    log["lr/unet"] = float(lrs[0])
                    if config.optimizer.optimizer_type == "Prodigy":
                        log["lr/d*lr/unet"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
                if config.train_text_encoder:
                    # When training the text encoder, it will always be the last parameter group.
                    log["lr/text_encoder"] = float(lrs[-1])
                    if config.optimizer.optimizer_type == "Prodigy":
                        log["lr/d*lr/text_encoder"] = optimizer.param_groups[-1]["d"] * optimizer.param_groups[-1]["lr"]

                accelerator.log(log, step=global_step)
                train_loss = 0.0

                # global_step represents the *number of completed steps* at this point.
                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:
                    save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

                if (
                    config.validate_every_n_steps is not None
                    and global_step % config.validate_every_n_steps == 0
                    and len(config.validation_prompts) > 0
                ):
                    validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            if global_step >= num_train_steps:
                break

        # Save a checkpoint every n epochs.
        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:
            save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

        # Generate validation images every n epochs.
        if (
            config.validate_every_n_epochs is not None
            and completed_epochs % config.validate_every_n_epochs == 0
            and len(config.validation_prompts) > 0
        ):
            validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

    accelerator.end_training()
================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/__init__.py
================================================
================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/config.py
================================================
from typing import Literal
from pydantic import model_validator
from invoke_training.config.base_pipeline_config import BasePipelineConfig
from invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig
from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig
class SdxlLoraAndTextualInversionConfig(BasePipelineConfig):
    """Configuration for the SDXL LoRA + Textual Inversion training pipeline."""

    type: Literal["SDXL_LORA_AND_TEXTUAL_INVERSION"] = "SDXL_LORA_AND_TEXTUAL_INVERSION"

    model: str = "stabilityai/stable-diffusion-xl-base-1.0"
    """Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint
    file. (E.g. 'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/JuggernautXL.safetensors', etc. )
    """

    hf_variant: str | None = "fp16"
    """The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.
    """

    lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"
    """The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`."""

    # Helpful discussion for understanding how this works at inference time:
    # https://github.com/huggingface/diffusers/pull/3144#discussion_r1172413509
    num_vectors: int = 1
    """Note: `num_vectors` can be overridden by `initial_phrase`.

    The number of textual inversion embedding vectors that will be used to learn the concept.

    Increasing the `num_vectors` enables the model to learn more complex concepts, but has the following drawbacks:

    - greater risk of overfitting
    - increased size of the resulting output file
    - consumes more of the prompt capacity at inference time

    Typical values for `num_vectors` are in the range [1, 16].

    As a rule of thumb, `num_vectors` can be increased as the size of the dataset increases (without overfitting).
    """

    placeholder_token: str
    """The special word to associate the learned embeddings with. Choose a unique token that is unlikely to already
    exist in the tokenizer's vocabulary.
    """

    initializer_token: str | None = None
    """A vocabulary token to use as an initializer for the placeholder token. It should be a single word that roughly
    describes the object or style that you're trying to train on. Must map to a single tokenizer token.

    For example, if you are training on a dataset of images of your pet dog, a good choice would be `dog`.
    """

    initial_phrase: str | None = None
    """Note: Exactly one of `initializer_token` or `initial_phrase` should be set.

    A phrase that will be used to initialize the placeholder token embedding. The phrase will be tokenized, and the
    corresponding embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be
    inferred from the length of the tokenized phrase, so keep the phrase short. The consequences of training a large
    number of embedding vectors are discussed in the `num_vectors` field documentation.

    For example, if you are training on a dataset of images of pokemon, you might use `pokemon sketch white background`.
    """

    train_unet: bool = True
    """Whether to add LoRA layers to the UNet model and train it.
    """

    train_text_encoder: bool = True
    """Whether to add LoRA layers to the text encoder and train it.
    """

    train_ti: bool = True
    """Whether to train the textual inversion embeddings."""

    ti_train_steps_ratio: float | None = None
    """The fraction of the total training steps for which the TI embeddings will be trained. For example, if we are
    training for a total of 5000 steps and `ti_train_steps_ratio=0.5`, then the TI embeddings will be trained for 2500
    steps and then frozen for the remaining steps.

    If `None`, then the TI embeddings will be trained for the entire duration of training.
    """

    optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()

    text_encoder_learning_rate: float | None = 1e-5
    """The learning rate to use for the text encoder model. Set to null or 0 to use the optimizer's default learning
    rate.
    """

    unet_learning_rate: float | None = 1e-4
    """The learning rate to use for the UNet model. Set to null or 0 to use the optimizer's default learning rate.
    """

    textual_inversion_learning_rate: float | None = 1e-3
    """The learning rate to use for textual inversion training of the embeddings. Set to null or 0 to use the
    optimizer's default learning rate.
    """

    lr_scheduler: Literal[
        "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"
    ] = "constant"

    lr_warmup_steps: int = 0
    """The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.
    See lr_scheduler.
    """

    min_snr_gamma: float | None = 5.0
    """Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy
    improves the speed of training convergence by adjusting the weight of each sample.

    `min_snr_gamma` acts like an upper bound on the weight of samples with low noise levels.

    If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`.
    """

    lora_rank_dim: int = 4
    """The rank dimension to use for the LoRA layers. Increasing the rank dimension increases the model's expressivity,
    but also increases the size of the generated LoRA model.
    """

    cache_text_encoder_outputs: bool = False
    """If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and
    the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the
    text encoders in VRAM), and speeds up training (don't have to run the text encoders for each training example).
    This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being
    applied.
    """

    cache_vae_outputs: bool = False
    """If True, the VAE will be applied to all of the images in the dataset before starting training and the results
    will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and
    speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all
    non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False).
    """

    enable_cpu_offload_during_validation: bool = False
    """If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation
    images. This reduces VRAM requirements at the cost of slower generation of validation images.
    """

    gradient_accumulation_steps: int = 1
    """The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face
    Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.
    """

    weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16"
    """All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and
    result in faster training, but are more prone to issues with numerical stability.

    Recommendations:

    - `"float32"`: Use this mode if you have plenty of VRAM available.
    - `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
    - `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.

    See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config.SdxlLoraAndTextualInversionConfig.mixed_precision].
    """  # noqa: E501

    mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
    """The mixed precision mode to use.

    If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and
    trainable parameters are kept in float32 precision to avoid issues with numerical stability.

    This value is passed to Hugging Face Accelerate. See
    [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)
    for more details.
    """  # noqa: E501

    xformers: bool = False
    """If true, use xformers for more efficient attention blocks.
    """

    gradient_checkpointing: bool = False
    """Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling
    gradient checkpointing slows down training by ~20%.
    """

    max_checkpoints: int | None = None
    """The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this
    limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.
    """

    prediction_type: Literal["epsilon", "v_prediction"] | None = None
    """The prediction_type that will be used for training. Choose between 'epsilon' or 'v_prediction' or leave 'None'.
    If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used.
    """

    max_grad_norm: float | None = None
    """Max gradient norm for clipping. Set to null or 0 for no clipping.
    """

    validation_prompts: list[str] = []
    """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.
    """

    negative_validation_prompts: list[str] | None = None
    """A list of negative prompts that will be applied when generating validation images. If set, this list should have
    the same length as 'validation_prompts'.
    """

    num_validation_images_per_prompt: int = 4
    """The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can
    become quite slow if this number is too large.
    """

    train_batch_size: int = 4
    """The training batch size.
    """

    use_masks: bool = False
    """If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this
    feature to be used.
    """

    data_loader: TextualInversionSDDataLoaderConfig
    """The data configuration.

    See
    [`TextualInversionSDDataLoaderConfig`][invoke_training.config.data.data_loader_config.TextualInversionSDDataLoaderConfig]
    for details.
    """

    vae_model: str | None = None
    """The name of the Hugging Face Hub VAE model to train against. This will override the VAE bundled with the base
    model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL shipped
    with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.
    """

    @model_validator(mode="after")
    def check_validation_prompts(self):
        """Ensure that, if negative validation prompts are given, there is exactly one per validation prompt."""
        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(
            self.validation_prompts
        ):
            raise ValueError(
                f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of "
                f"negative_validation_prompts ({len(self.negative_validation_prompts)})."
            )
        return self
================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/train.py
================================================
import itertools
import json
import logging
import math
import os
import time
from pathlib import Path
from typing import Literal
import peft
import torch
import torch.utils.data
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import UNet2DConditionModel
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm
from transformers import CLIPTextModel
from invoke_training._shared.accelerator.accelerator_utils import (
get_dtype_from_str,
initialize_accelerator,
initialize_logging,
)
from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker
from invoke_training._shared.checkpoints.serialization import save_state_dict
from invoke_training._shared.data.data_loaders.textual_inversion_sd_dataloader import (
build_textual_inversion_sd_dataloader,
)
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets
from invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer
from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (
TEXT_ENCODER_TARGET_MODULES,
UNET_TARGET_MODULES,
save_sdxl_kohya_checkpoint,
save_sdxl_peft_checkpoint,
)
from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sdxl
from invoke_training._shared.stable_diffusion.textual_inversion import restore_original_embeddings
from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sdxl
from invoke_training._shared.utils.import_xformers import import_xformers
from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint
from invoke_training.pipelines.stable_diffusion_xl.lora.train import train_forward
from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import (
SdxlLoraAndTextualInversionConfig,
)
from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.train import _initialize_placeholder_tokens
def _save_sdxl_lora_and_ti_checkpoint(
    config: SdxlLoraAndTextualInversionConfig,
    epoch: int,
    step: int,
    unet: peft.PeftModel | None,
    text_encoder_1: peft.PeftModel | None,
    text_encoder_2: peft.PeftModel | None,
    placeholder_token_ids_1: list[int],
    placeholder_token_ids_2: list[int],
    accelerator: Accelerator,
    logger: logging.Logger,
    checkpoint_tracker: CheckpointTracker,
    lora_checkpoint_format: Literal["invoke_peft", "kohya"],
    callbacks: list[PipelineCallbacks] | None,
):
    """Save a combined LoRA + Textual Inversion SDXL checkpoint and notify the callbacks.

    Old checkpoints are pruned first so that the `checkpoint_tracker` limit is respected. The LoRA weights of the
    trained models are saved in `lora_checkpoint_format`. If `config.train_ti` is set, the learned placeholder-token
    embeddings of both text encoders are saved to `embeddings.safetensors` in the same checkpoint directory.

    Args:
        config: The pipeline config (controls which models are saved).
        epoch: The number of completed epochs (used to name the checkpoint).
        step: The number of completed steps (used to name the checkpoint).
        unet: The UNet (with LoRA layers if `config.train_unet`).
        text_encoder_1: The first text encoder.
        text_encoder_2: The second text encoder.
        placeholder_token_ids_1: Token IDs of the placeholder tokens in the first tokenizer. Only used when
            `config.train_ti` is set.
        placeholder_token_ids_2: Token IDs of the placeholder tokens in the second tokenizer. Only used when
            `config.train_ti` is set.
        accelerator: Used to unwrap the text encoders before reading their embeddings.
        logger: Logger used to report pruning.
        checkpoint_tracker: Determines the checkpoint path and prunes old checkpoints.
        lora_checkpoint_format: The LoRA serialization format.
        callbacks: Optional callbacks notified via `on_save_checkpoint(...)` after saving.

    Raises:
        ValueError: If `lora_checkpoint_format` is not supported.
    """
    # Prune checkpoints and get new checkpoint path.
    num_pruned = checkpoint_tracker.prune(1)
    if num_pruned > 0:
        logger.info(f"Pruned {num_pruned} checkpoint(s).")
    save_path = checkpoint_tracker.get_path(epoch=epoch, step=step)
    checkpoint_dir = Path(save_path)

    # Only models that are configured for LoRA training are passed to the LoRA serializers.
    lora_unet = unet if config.train_unet else None
    lora_text_encoder_1 = text_encoder_1 if config.train_text_encoder else None
    lora_text_encoder_2 = text_encoder_2 if config.train_text_encoder else None

    training_checkpoint = TrainingCheckpoint(models=[], epoch=epoch, step=step)
    if lora_checkpoint_format == "invoke_peft":
        save_sdxl_peft_checkpoint(
            checkpoint_dir,
            unet=lora_unet,
            text_encoder_1=lora_text_encoder_1,
            text_encoder_2=lora_text_encoder_2,
        )
        training_checkpoint.models.append(ModelCheckpoint(file_path=save_path, model_type=ModelType.SDXL_LORA_PEFT))
    elif lora_checkpoint_format == "kohya":
        # A kohya LoRA is a single file, so report that file (not the enclosing checkpoint directory) to callbacks.
        kohya_lora_path = checkpoint_dir / "lora.safetensors"
        save_sdxl_kohya_checkpoint(
            kohya_lora_path,
            unet=lora_unet,
            text_encoder_1=lora_text_encoder_1,
            text_encoder_2=lora_text_encoder_2,
        )
        training_checkpoint.models.append(
            ModelCheckpoint(file_path=str(kohya_lora_path), model_type=ModelType.SDXL_LORA_KOHYA)
        )
    else:
        raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.")

    if config.train_ti:
        ti_checkpoint_path = checkpoint_dir / "embeddings.safetensors"
        # The LoRA save above may not have created the checkpoint directory (e.g. when no LoRA models are being
        # trained), so make sure it exists before writing the embeddings file into it.
        ti_checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

        # The placeholder tokens are slices of the embedding matrices spanning [min_id, max_id].
        learned_embeds_1 = (
            accelerator.unwrap_model(text_encoder_1)
            .get_input_embeddings()
            .weight[min(placeholder_token_ids_1) : max(placeholder_token_ids_1) + 1]
        )
        learned_embeds_2 = (
            accelerator.unwrap_model(text_encoder_2)
            .get_input_embeddings()
            .weight[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1]
        )
        learned_embeds_dict = {
            "clip_l": learned_embeds_1.detach().cpu().to(dtype=torch.float32),
            "clip_g": learned_embeds_2.detach().cpu().to(dtype=torch.float32),
        }
        save_state_dict(learned_embeds_dict, ti_checkpoint_path)
        training_checkpoint.models.append(
            ModelCheckpoint(file_path=ti_checkpoint_path, model_type=ModelType.SDXL_TEXTUAL_INVERSION)
        )

    if callbacks is not None:
        for cb in callbacks:
            cb.on_save_checkpoint(training_checkpoint)
def train(config: SdxlLoraAndTextualInversionConfig, callbacks: list[PipelineCallbacks] | None = None):  # noqa: C901
    """Train LoRA layers and/or Textual Inversion embeddings for an SDXL model.

    A new timestamped directory is created under `config.base_output_dir`. The resolved config (`config.json`) and the
    checkpoints are written there, and the directory is also passed to the accelerator and to validation-image
    generation.

    Which components are trained is controlled by the config:
    - `train_unet`: inject and train LoRA layers in the UNet.
    - `train_text_encoder`: inject and train LoRA layers in both text encoders.
    - `train_ti`: add placeholder tokens and train their embeddings in both text encoders. If `ti_train_steps_ratio`
      is set, the TI learning rate is forced to 0.0 for every optimizer step after that fraction of the total
      training steps has completed.

    Args:
        config: The pipeline configuration.
        callbacks: Optional pipeline callbacks, forwarded to checkpoint saving and validation-image generation.

    Raises:
        NotImplementedError: If `cache_text_encoder_outputs` or `cache_vae_outputs` is enabled.
    """
    # Give a clear error message if an unsupported base model was chosen.
    # TODO(ryan): Update this check to work with single-file SD checkpoints.
    # check_base_model_version(
    #     {BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE},
    #     config.model,
    #     local_files_only=False,
    # )

    # Create a timestamped directory for all outputs.
    out_dir = os.path.join(config.base_output_dir, f"{time.time()}")
    ckpt_dir = os.path.join(out_dir, "checkpoints")
    os.makedirs(ckpt_dir)

    accelerator = initialize_accelerator(
        out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to
    )
    logger = initialize_logging(os.path.basename(__file__), accelerator)

    # Set the accelerate seed.
    if config.seed is not None:
        set_seed(config.seed)

    # Log the accelerator configuration from every process to help with debugging.
    logger.info(accelerator.state, main_process_only=False)

    logger.info("Starting Training.")
    logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}")
    logger.info(f"Output dir: '{out_dir}'")

    # Write the configuration to disk.
    with open(os.path.join(out_dir, "config.json"), "w") as f:
        json.dump(config.dict(), f, indent=2, default=str)

    weight_dtype = get_dtype_from_str(config.weight_dtype)

    logger.info("Loading models.")
    tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl(
        logger=logger,
        model_name_or_path=config.model,
        hf_variant=config.hf_variant,
        vae_model=config.vae_model,
        dtype=weight_dtype,
    )

    if config.xformers:
        import_xformers()

        # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers
        # will fail because Q, K, V have different dtypes.
        unet.enable_xformers_memory_efficient_attention()
        vae.enable_xformers_memory_efficient_attention()

    # Prepare text encoder output cache.
    if config.cache_text_encoder_outputs:
        raise NotImplementedError("Caching text encoder outputs is not yet supported.")
    else:
        text_encoder_1.to(accelerator.device, dtype=weight_dtype)
        text_encoder_2.to(accelerator.device, dtype=weight_dtype)

    # Prepare VAE output cache.
    vae_output_cache_dir_name = None
    if config.cache_vae_outputs:
        raise NotImplementedError("Caching VAE outputs is not yet supported.")
    else:
        vae.to(accelerator.device, dtype=weight_dtype)

    unet.to(accelerator.device, dtype=weight_dtype)

    # Add LoRA layers to the models being trained.
    trainable_param_groups = []
    all_trainable_models: set[torch.nn.Module] = set()

    def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float) -> peft.PeftModel:
        """Wrap `model` with LoRA layers and register its trainable params as a new optimizer param group."""
        peft_model = peft.get_peft_model(model, lora_config)
        peft_model.print_trainable_parameters()

        # Populate `trainable_param_groups`, to be passed to the optimizer.
        param_group = {"params": list(filter(lambda p: p.requires_grad, peft_model.parameters())), "lr": lr}
        trainable_param_groups.append(param_group)

        # Populate all_trainable_models.
        all_trainable_models.add(peft_model)

        peft_model.train()

        return peft_model

    if config.train_unet:
        unet_lora_config = peft.LoraConfig(
            r=config.lora_rank_dim,
            # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred?
            lora_alpha=1.0,
            target_modules=UNET_TARGET_MODULES,
        )
        unet = inject_lora_layers(unet, unet_lora_config, lr=config.unet_learning_rate)

    if config.train_text_encoder:
        text_encoder_lora_config = peft.LoraConfig(
            r=config.lora_rank_dim,
            lora_alpha=1.0,
            # init_lora_weights="gaussian",
            target_modules=TEXT_ENCODER_TARGET_MODULES,
        )
        text_encoder_1 = inject_lora_layers(
            text_encoder_1, text_encoder_lora_config, lr=config.text_encoder_learning_rate
        )
        text_encoder_2 = inject_lora_layers(
            text_encoder_2, text_encoder_lora_config, lr=config.text_encoder_learning_rate
        )

    # The placeholder token IDs are only populated when training TI. They default to empty lists so that the
    # checkpoint-saving closure below can always reference them (they are only read when `config.train_ti` is set).
    placeholder_token_ids_1: list[int] = []
    placeholder_token_ids_2: list[int] = []
    if config.train_ti:
        # TODO(ryand): Move this private function to a shared location.
        placeholder_tokens, placeholder_token_ids_1, placeholder_token_ids_2 = _initialize_placeholder_tokens(
            config=config,
            tokenizer_1=tokenizer_1,
            tokenizer_2=tokenizer_2,
            text_encoder_1=text_encoder_1,
            text_encoder_2=text_encoder_2,
            logger=logger,
        )
        logger.info(f"Initialized {len(placeholder_tokens)} placeholder tokens: {placeholder_tokens}.")

        # Unfreeze the token embeddings in the text encoders.
        text_encoder_1.text_model.embeddings.token_embedding.requires_grad_(True)
        text_encoder_2.text_model.embeddings.token_embedding.requires_grad_(True)

        all_trainable_models.add(text_encoder_1)
        all_trainable_models.add(text_encoder_2)

        # NOTE: The TI param groups must be the last two param groups; the TI pivot logic below relies on this.
        for te in [text_encoder_1, text_encoder_2]:
            param_group = {
                "params": te.get_input_embeddings().parameters(),
                "lr": config.textual_inversion_learning_rate,
            }
            trainable_param_groups.append(param_group)

    # If mixed_precision is enabled, cast all trainable params to float32.
    if config.mixed_precision != "no":
        for trainable_model in all_trainable_models:
            for param in trainable_model.parameters():
                if param.requires_grad:
                    param.data = param.to(torch.float32)

    if config.gradient_checkpointing:
        # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.
        unet.enable_gradient_checkpointing()
        # unet must be in train() mode for gradient checkpointing to take effect.
        # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does
        # not change its forward behavior.
        unet.train()
        if config.train_text_encoder:
            for te in [text_encoder_1, text_encoder_2]:
                te.gradient_checkpointing_enable()

                # The text encoders must be in train() mode for gradient checkpointing to take effect. This should
                # already be the case, since we are training the text_encoders, but we do it explicitly to make it
                # clear that this is required.
                # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text
                # encoders in train mode does not change their forward behavior.
                te.train()

                # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder
                # LoRA weights would have 0 gradients, and so would not get trained. Note that the set of
                # trainable_param_groups has already been populated - this won't change what gets trained.
                te.text_model.embeddings.requires_grad_(True)

    optimizer = initialize_optimizer(config.optimizer, trainable_param_groups)

    data_loader = build_textual_inversion_sd_dataloader(
        config=config.data_loader,
        placeholder_token=config.placeholder_token,
        batch_size=config.train_batch_size,
        use_masks=config.use_masks,
        vae_output_cache_dir=vae_output_cache_dir_name,
    )

    log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler)

    assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1
    assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1
    assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1

    # A "step" represents a single weight update operation (i.e. takes into account gradient accumulation steps).
    # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when
    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.
    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)
    num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch
    num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch)

    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps
    # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears
    # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process
    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),
    # so the scaling here simply reverses that behaviour.
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(
        config.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=num_train_steps * accelerator.num_processes,
    )

    prepared_result: tuple[
        UNet2DConditionModel,
        peft.PeftModel | CLIPTextModel,
        peft.PeftModel | CLIPTextModel,
        torch.optim.Optimizer,
        torch.utils.data.DataLoader,
        torch.optim.lr_scheduler.LRScheduler,
    ] = accelerator.prepare(
        unet,
        text_encoder_1,
        text_encoder_2,
        optimizer,
        data_loader,
        lr_scheduler,
        # Disable automatic device placement for text_encoder if the text encoder outputs were cached.
        device_placement=[
            True,
            not config.cache_text_encoder_outputs,
            not config.cache_text_encoder_outputs,
            True,
            True,
            True,
        ],
    )
    unet, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result

    if accelerator.is_main_process:
        accelerator.init_trackers("lora_and_ti_training")
        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.
        accelerator.log({"configuration": f"```json\n{json.dumps(config.dict(), indent=2, default=str)}\n```\n"})

    checkpoint_tracker = CheckpointTracker(
        base_dir=ckpt_dir,
        prefix="checkpoint",
        max_checkpoints=config.max_checkpoints,
    )

    # Train!
    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num batches = {len(data_loader)}")
    logger.info(f"  Instantaneous batch size per device = {config.train_batch_size}")
    logger.info(f"  Gradient accumulation steps = {config.gradient_accumulation_steps}")
    logger.info(f"  Parallel processes = {accelerator.num_processes}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Total optimization steps = {num_train_steps}")
    logger.info(f"  Total epochs = {num_train_epochs}")

    global_step = 0
    first_epoch = 0
    completed_epochs = first_epoch

    progress_bar = tqdm(
        range(global_step, num_train_steps),
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    ti_train_steps = num_train_steps
    if config.ti_train_steps_ratio is not None:
        ti_train_steps = math.ceil(num_train_steps * config.ti_train_steps_ratio)
        logger.info(f"The TI training pivot point is set at {ti_train_steps} steps.")

    # Keep original embeddings as reference. These are full copies of both embedding matrices, so they are only
    # made when TI is actually being trained.
    if config.train_ti:
        with torch.no_grad():
            orig_embeds_params_1 = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()
            orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()

    # LR names, in the same order that their respective param groups were added to the optimizer above.
    lr_names: list[str] = []
    if config.train_unet:
        lr_names.append("unet")
    if config.train_text_encoder:
        lr_names.extend(["text_encoder_1", "text_encoder_2"])
    if config.train_ti:
        lr_names.extend(["ti_embeddings_1", "ti_embeddings_2"])

    def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):
        """Save a checkpoint on the main process, synchronizing all processes before and after."""
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            _save_sdxl_lora_and_ti_checkpoint(
                config=config,
                epoch=num_completed_epochs,
                step=num_completed_steps,
                unet=unet,
                text_encoder_1=text_encoder_1,
                text_encoder_2=text_encoder_2,
                placeholder_token_ids_1=placeholder_token_ids_1,
                placeholder_token_ids_2=placeholder_token_ids_2,
                accelerator=accelerator,
                logger=logger,
                checkpoint_tracker=checkpoint_tracker,
                lora_checkpoint_format=config.lora_checkpoint_format,
                callbacks=callbacks,
            )
        accelerator.wait_for_everyone()

    def validate(num_completed_epochs: int, num_completed_steps: int):
        """Generate validation images on the main process, synchronizing all processes before and after."""
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            generate_validation_images_sdxl(
                epoch=num_completed_epochs,
                step=num_completed_steps,
                out_dir=out_dir,
                accelerator=accelerator,
                vae=vae,
                text_encoder_1=text_encoder_1,
                text_encoder_2=text_encoder_2,
                tokenizer_1=tokenizer_1,
                tokenizer_2=tokenizer_2,
                noise_scheduler=noise_scheduler,
                unet=unet,
                config=config,
                logger=logger,
                callbacks=callbacks,
            )
        accelerator.wait_for_everyone()

    ti_pivot_reached = False
    for epoch in range(first_epoch, num_train_epochs):
        # TODO(ryand): Is this necessary?
        text_encoder_1.train()
        text_encoder_2.train()

        train_loss = 0.0
        for data_batch_idx, data_batch in enumerate(data_loader):
            if config.train_ti and global_step >= ti_train_steps:
                if not ti_pivot_reached:
                    logger.info("Reached TI training pivot point. Setting TI learning rate to 0.0.")
                    ti_pivot_reached = True
                # The LR scheduler recomputes every param group's "lr" on each lr_scheduler.step(), so the TI LR has
                # to be re-zeroed before *every* optimizer step after the pivot. Zeroing it only once would let TI
                # training silently resume after a single step.
                # TODO(ryand): The TI embeddings continue to be updated slightly by the normalization step in
                # restore_original_embeddings(...). The updates should be very small and converge quickly, so this
                # should be fine. But, at some point we should tidy this up.
                for ti_param_group in optimizer.param_groups[-2:]:
                    # The TI param groups should be the last two param groups. But, this is pretty brittle, so this
                    # assertion adds a bit of safety.
                    assert len(ti_param_group["params"]) == 1
                    ti_param_group["lr"] = 0.0

            with accelerator.accumulate(unet, text_encoder_1, text_encoder_2):
                loss = train_forward(
                    accelerator=accelerator,
                    data_batch=data_batch,
                    vae=vae,
                    noise_scheduler=noise_scheduler,
                    tokenizer_1=tokenizer_1,
                    tokenizer_2=tokenizer_2,
                    text_encoder_1=text_encoder_1,
                    text_encoder_2=text_encoder_2,
                    unet=unet,
                    weight_dtype=weight_dtype,
                    resolution=config.data_loader.resolution,
                    use_masks=config.use_masks,
                    prediction_type=config.prediction_type,
                    min_snr_gamma=config.min_snr_gamma,
                )

                # Gather the losses across all processes for logging (if we use distributed training).
                # TODO(ryand): Test that this works properly with distributed training.
                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()
                train_loss += avg_loss.item() / config.gradient_accumulation_steps

                # Backpropagate.
                accelerator.backward(loss)
                # A max_grad_norm of None or 0 disables clipping (clipping to a max norm of 0 would zero every
                # gradient).
                if accelerator.sync_gradients and config.max_grad_norm:
                    params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models])
                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

                if config.train_ti:
                    # Make sure we don't update any embedding weights besides the newly-added token(s).
                    # TODO(ryand): Should we only do this if accelerator.sync_gradients?
                    restore_original_embeddings(
                        tokenizer=tokenizer_1,
                        placeholder_token_ids=placeholder_token_ids_1,
                        accelerator=accelerator,
                        text_encoder=text_encoder_1,
                        orig_embeds_params=orig_embeds_params_1,
                    )
                    restore_original_embeddings(
                        tokenizer=tokenizer_2,
                        placeholder_token_ids=placeholder_token_ids_2,
                        accelerator=accelerator,
                        text_encoder=text_encoder_2,
                        orig_embeds_params=orig_embeds_params_2,
                    )

            # Checks if the accelerator has performed an optimization step behind the scenes.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1

                log = {"train_loss": train_loss}
                lrs = lr_scheduler.get_last_lr()
                for lr_idx, lr_name in enumerate(lr_names):
                    # After the TI pivot, the scheduler still reports the scheduled TI LR, but the LR actually applied
                    # to the TI embeddings is 0.0.
                    ti_frozen = ti_pivot_reached and lr_name.startswith("ti_embeddings")
                    log[f"lr/{lr_name}"] = 0.0 if ti_frozen else float(lrs[lr_idx])
                    if config.optimizer.optimizer_type == "Prodigy":
                        group = optimizer.param_groups[lr_idx]
                        log[f"lr/d*lr/{lr_name}"] = 0.0 if ti_frozen else group["d"] * group["lr"]

                accelerator.log(log, step=global_step)
                train_loss = 0.0

                # global_step represents the *number of completed steps* at this point.
                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:
                    save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

                if (
                    config.validate_every_n_steps is not None
                    and global_step % config.validate_every_n_steps == 0
                    and len(config.validation_prompts) > 0
                ):
                    validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            if global_step >= num_train_steps:
                break

        # Save a checkpoint every n epochs.
        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:
            save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

        # Generate validation images every n epochs.
        if (
            config.validate_every_n_epochs is not None
            and completed_epochs % config.validate_every_n_epochs == 0
            and len(config.validation_prompts) > 0
        ):
            validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

    accelerator.end_training()
================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/__init__.py
================================================
================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/config.py
================================================
from typing import Literal
from pydantic import model_validator
from invoke_training.config.base_pipeline_config import BasePipelineConfig
from invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig
from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig
class SdxlTextualInversionConfig(BasePipelineConfig):
    """Configuration for the SDXL Textual Inversion training pipeline."""

    type: Literal["SDXL_TEXTUAL_INVERSION"] = "SDXL_TEXTUAL_INVERSION"
    """Must be `SDXL_TEXTUAL_INVERSION`. This is what differentiates training pipeline types.
    """

    model: str = "stabilityai/stable-diffusion-xl-base-1.0"
    """Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint
    file. (E.g. 'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/JuggernautXL.safetensors', etc. )
    """

    hf_variant: str | None = "fp16"
    """The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.
    """

    # Helpful discussion for understanding how this works at inference time:
    # https://github.com/huggingface/diffusers/pull/3144#discussion_r1172413509
    num_vectors: int = 1
    """Note: `num_vectors` can be overridden by `initial_phrase`.

    The number of textual inversion embedding vectors that will be used to learn the concept.

    Increasing the `num_vectors` enables the model to learn more complex concepts, but has the following drawbacks:

    - greater risk of overfitting
    - increased size of the resulting output file
    - consumes more of the prompt capacity at inference time

    Typical values for `num_vectors` are in the range [1, 16].

    As a rule of thumb, `num_vectors` can be increased as the size of the dataset increases (without overfitting).
    """

    placeholder_token: str
    """The special word to associate the learned embeddings with. Choose a unique token that is unlikely to already
    exist in the tokenizer's vocabulary.
    """

    initializer_token: str | None = None
    """Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set.

    A vocabulary token to use as an initializer for the placeholder token. It should be a single word that roughly
    describes the object or style that you're trying to train on. Must map to a single tokenizer token.

    For example, if you are training on a dataset of images of your pet dog, a good choice would be `dog`.
    """

    initial_embedding_file: str | None = None
    """Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set.

    Path to an existing TI embedding that will be used to initialize the embedding being trained. The placeholder
    token in the file must match the `placeholder_token` field.
    """

    initial_phrase: str | None = None
    """Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set.

    A phrase that will be used to initialize the placeholder token embedding. The phrase will be tokenized, and the
    corresponding embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be
    inferred from the length of the tokenized phrase, so keep the phrase short. The consequences of training a large
    number of embedding vectors are discussed in the `num_vectors` field documentation.

    For example, if you are training on a dataset of images of pokemon, you might use `pokemon sketch white background`.
    """

    # The optimizer (and its hyperparameters) used to update the TI embeddings.
    optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()

    # The learning rate schedule applied on top of the optimizer's learning rate.
    lr_scheduler: Literal[
        "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"
    ] = "constant"

    lr_warmup_steps: int = 0
    """The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.
    See lr_scheduler.
    """

    min_snr_gamma: float | None = 5.0
    """Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy
    improves the speed of training convergence by adjusting the weight of each sample.

    `min_snr_gamma` acts like an upper bound on the weight of samples with low noise levels.

    If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`.
    """

    cache_vae_outputs: bool = False
    """If True, the VAE will be applied to all of the images in the dataset before starting training and the results
    will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and
    speeds up training (don't have to run the VAE encoding step).

    This option can only be enabled if all non-deterministic image augmentations are disabled (i.e. `center_crop=True`,
    `random_flip=False`, etc.).
    """

    enable_cpu_offload_during_validation: bool = False
    """If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation
    images. This reduces VRAM requirements at the cost of slower generation of validation images.
    """

    gradient_accumulation_steps: int = 1
    """The number of gradient steps to accumulate before each weight update. This is an alternative to increasing the
    `train_batch_size` when training with limited VRAM.
    """

    weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16"
    """All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and
    result in faster training, but are more prone to issues with numerical stability.

    Recommendations:

    - `"float32"`: Use this mode if you have plenty of VRAM available.
    - `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
    - `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.

    See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config.SdxlTextualInversionConfig.mixed_precision].
    """  # noqa: E501

    mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
    """The mixed precision mode to use.

    If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and
    trainable parameters are kept in float32 precision to avoid issues with numerical stability.

    This value is passed to Hugging Face Accelerate. See
    [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)
    for more details.
    """  # noqa: E501

    xformers: bool = False
    """If `True`, use xformers for more efficient attention blocks.
    """

    gradient_checkpointing: bool = False
    """Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling
    gradient checkpointing slows down training by ~20%.
    """

    max_checkpoints: int | None = None
    """The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this
    limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.
    """

    prediction_type: Literal["epsilon", "v_prediction"] | None = None
    """The prediction type that will be used for training. If `None`, the prediction type will be inferred from the
    scheduler.
    """

    max_grad_norm: float | None = None
    """Maximum gradient norm for gradient clipping. Set to `None` for no clipping.
    """

    validation_prompts: list[str] = []
    """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.
    """

    negative_validation_prompts: list[str] | None = None
    """A list of negative prompts that will be applied when generating validation images. If set, this list should have
    the same length as 'validation_prompts'.
    """

    num_validation_images_per_prompt: int = 4
    """The number of validation images to generate for each prompt in `validation_prompts`. Careful, validation can
    become very slow if this number is too large.
    """

    train_batch_size: int = 4
    """The training batch size.
    """

    use_masks: bool = False
    """If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this
    feature to be used.
    """

    data_loader: TextualInversionSDDataLoaderConfig
    """The data configuration.

    See
    [`TextualInversionSDDataLoaderConfig`][invoke_training.config.data.data_loader_config.TextualInversionSDDataLoaderConfig]
    for details.
    """

    vae_model: str | None = None
    """The name of the Hugging Face Hub VAE model to train against. If set, this will override the VAE bundled with the
    base model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL 1.0
    shipped with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.
    """

    @model_validator(mode="after")
    def check_validation_prompts(self):
        # Negative validation prompts are optional, but when provided they must pair one-to-one with the prompts.
        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(
            self.validation_prompts
        ):
            raise ValueError(
                f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of "
                f"negative_validation_prompts ({len(self.negative_validation_prompts)})."
            )
        return self
================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/train.py
================================================
import json
import logging
import math
import os
import tempfile
import time
import torch
import torch.utils.data
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers.optimization import get_scheduler
from tqdm.auto import tqdm
from transformers import CLIPPreTrainedModel, CLIPTextModel, CLIPTokenizer, PreTrainedTokenizer
from invoke_training._shared.accelerator.accelerator_utils import (
get_dtype_from_str,
initialize_accelerator,
initialize_logging,
)
from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker
from invoke_training._shared.checkpoints.serialization import save_state_dict
from invoke_training._shared.data.data_loaders.textual_inversion_sd_dataloader import (
build_textual_inversion_sd_dataloader,
)
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets
from invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer
from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sdxl
from invoke_training._shared.stable_diffusion.textual_inversion import (
initialize_placeholder_tokens_from_initial_phrase,
initialize_placeholder_tokens_from_initializer_token,
restore_original_embeddings,
)
from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sdxl
from invoke_training._shared.utils.import_xformers import import_xformers
from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint
from invoke_training.pipelines.stable_diffusion_xl.lora.train import cache_vae_outputs, train_forward
from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import (
SdxlLoraAndTextualInversionConfig,
)
from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig
def _get_placeholder_embeddings(
    accelerator: Accelerator, text_encoder: CLIPTextModel, placeholder_token_ids: list[int]
) -> torch.Tensor:
    """Return a detached float32 CPU copy of the input-embedding rows for `placeholder_token_ids`.

    Rows are gathered by explicit index, in the order given, so the result has exactly one row per placeholder token
    even when the ids are not contiguous in the vocabulary. For contiguous ascending ids this is identical to slicing
    `weight[min(ids) : max(ids) + 1]`.
    """
    weight = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
    # Wrap in list(...) so that a tuple of ids is not interpreted as a multi-dimensional index.
    return weight[list(placeholder_token_ids)].detach().cpu().to(dtype=torch.float32)


def _save_ti_embeddings(
    epoch: int,
    step: int,
    text_encoder_1: CLIPTextModel,
    text_encoder_2: CLIPTextModel,
    placeholder_token_ids_1: list[int],
    placeholder_token_ids_2: list[int],
    accelerator: Accelerator,
    logger: logging.Logger,
    checkpoint_tracker: CheckpointTracker,
    callbacks: list[PipelineCallbacks] | None,
):
    """Save a Textual Inversion SDXL checkpoint. Old checkpoints are deleted if necessary to respect the
    checkpoint_tracker limits.

    The saved state dict has two entries: "clip_l" (rows of text_encoder_1's input embeddings) and "clip_g" (rows of
    text_encoder_2's input embeddings). Each holds one float32 row per placeholder token id.

    Args:
        epoch: Number of completed epochs, used to name the checkpoint.
        step: Number of completed steps, used to name the checkpoint.
        text_encoder_1: First SDXL text encoder (possibly wrapped by accelerate).
        text_encoder_2: Second SDXL text encoder (possibly wrapped by accelerate).
        placeholder_token_ids_1: Placeholder token ids in tokenizer_1's vocabulary.
        placeholder_token_ids_2: Placeholder token ids in tokenizer_2's vocabulary.
        accelerator: Used to unwrap the text encoders.
        logger: Logger for pruning messages.
        checkpoint_tracker: Provides the output path and enforces the max-checkpoint limit.
        callbacks: Optional callbacks whose `on_save_checkpoint` is invoked after the file is written.
    """
    # Prune checkpoints and get new checkpoint path.
    num_pruned = checkpoint_tracker.prune(1)
    if num_pruned > 0:
        logger.info(f"Pruned {num_pruned} checkpoint(s).")
    save_path = checkpoint_tracker.get_path(epoch=epoch, step=step)

    learned_embeds_dict = {
        "clip_l": _get_placeholder_embeddings(accelerator, text_encoder_1, placeholder_token_ids_1),
        "clip_g": _get_placeholder_embeddings(accelerator, text_encoder_2, placeholder_token_ids_2),
    }

    save_state_dict(learned_embeds_dict, save_path)

    if callbacks is not None:
        for cb in callbacks:
            cb.on_save_checkpoint(
                TrainingCheckpoint(
                    models=[ModelCheckpoint(file_path=save_path, model_type=ModelType.SDXL_TEXTUAL_INVERSION)],
                    epoch=epoch,
                    step=step,
                )
            )
def _initialize_placeholder_tokens(
    config: SdxlTextualInversionConfig | SdxlLoraAndTextualInversionConfig,
    tokenizer_1: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    text_encoder_1: PreTrainedTokenizer,
    text_encoder_2: PreTrainedTokenizer,
    logger: logging.Logger,
) -> tuple[list[str], list[int], list[int]]:
    """Prepare the tokenizers and text_encoders for TI training.

    - Add the placeholder tokens to the tokenizers.
    - Add new token embeddings to the text_encoders for each of the placeholder tokens.
    - Initialize the new token embeddings from either an existing token (`initializer_token`) or an initial phrase
      (`initial_phrase`). Initializing from `initial_embedding_file` is not yet supported for SDXL.

    Args:
        config: Training config. Exactly one of `initializer_token`, `initial_embedding_file`, `initial_phrase` must be
            set (missing attributes are treated as unset).
        tokenizer_1: Tokenizer paired with text_encoder_1.
        tokenizer_2: Tokenizer paired with text_encoder_2.
        text_encoder_1: Text encoder that receives the new embeddings for tokenizer_1's placeholder tokens.
        text_encoder_2: Text encoder that receives the new embeddings for tokenizer_2's placeholder tokens.
        logger: Passed through to the initializer-token helper.

    Returns:
        (placeholder_tokens, placeholder_token_ids_1, placeholder_token_ids_2), where the id lists refer to
        tokenizer_1's and tokenizer_2's vocabularies respectively.

    Raises:
        ValueError: If not exactly one initialization option is set, or if the two tokenizers end up with different
            placeholder token lists.
        NotImplementedError: If `initial_embedding_file` is set.
    """
    # getattr(...) because the LoRA+TI config may not define every initialization option.
    initializer_token = getattr(config, "initializer_token", None)
    initial_embedding_file = getattr(config, "initial_embedding_file", None)
    initial_phrase = getattr(config, "initial_phrase", None)

    num_options_set = sum(opt is not None for opt in (initializer_token, initial_embedding_file, initial_phrase))
    if num_options_set != 1:
        raise ValueError(
            "Exactly one of 'initializer_token', 'initial_embedding_file', or 'initial_phrase' should be set."
        )

    if initial_embedding_file is not None:
        # TODO(ryan)
        raise NotImplementedError("Initializing from an initial embedding is not yet supported for SDXL.")

    def _init_for(tokenizer, text_encoder):
        # Exactly one of initializer_token / initial_phrase is set at this point.
        if initializer_token is not None:
            return initialize_placeholder_tokens_from_initializer_token(
                tokenizer=tokenizer,
                text_encoder=text_encoder,
                initializer_token=initializer_token,
                placeholder_token=config.placeholder_token,
                num_vectors=config.num_vectors,
                logger=logger,
            )
        return initialize_placeholder_tokens_from_initial_phrase(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            initial_phrase=initial_phrase,
            placeholder_token=config.placeholder_token,
        )

    placeholder_tokens_1, placeholder_token_ids_1 = _init_for(tokenizer_1, text_encoder_1)
    placeholder_tokens_2, placeholder_token_ids_2 = _init_for(tokenizer_2, text_encoder_2)

    # Explicit check (not `assert`) so it is not stripped when running with `python -O`.
    if placeholder_tokens_1 != placeholder_tokens_2:
        raise ValueError(
            f"Placeholder tokens differ between tokenizer_1 ({placeholder_tokens_1}) and tokenizer_2 "
            f"({placeholder_tokens_2})."
        )
    return placeholder_tokens_1, placeholder_token_ids_1, placeholder_token_ids_2
def train(config: SdxlTextualInversionConfig, callbacks: list[PipelineCallbacks] | None = None):  # noqa: C901
    """Run SDXL Textual Inversion training.

    Only the input token embeddings of the two SDXL text encoders are optimized. After every optimizer step, all
    embedding rows other than the placeholder token rows are restored to their original values. Checkpoints
    (placeholder-token embeddings only) and validation images are written under a new timestamped subdirectory of
    `config.base_output_dir`.

    Args:
        config: The training configuration.
        callbacks: Optional callbacks. They are notified when checkpoints are saved and are forwarded to validation
            image generation.

    Raises:
        ValueError: If `cache_vae_outputs` is combined with `random_flip` or with `center_crop=False`, if the
            step/epoch limits are not configured as exactly-one-of pairs, or if the data loader yields no batches.
    """
    # Create a timestamped directory for all outputs.
    out_dir = os.path.join(config.base_output_dir, f"{time.time()}")
    ckpt_dir = os.path.join(out_dir, "checkpoints")
    os.makedirs(ckpt_dir)

    accelerator = initialize_accelerator(
        out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to
    )
    logger = initialize_logging(os.path.basename(__file__), accelerator)

    # Set the accelerate seed.
    if config.seed is not None:
        set_seed(config.seed)

    # Log the accelerator configuration from every process to help with debugging.
    logger.info(accelerator.state, main_process_only=False)

    logger.info("Starting Training.")
    logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}")
    logger.info(f"Output dir: '{out_dir}'")

    # Write the configuration to disk.
    with open(os.path.join(out_dir, "config.json"), "w") as f:
        json.dump(config.dict(), f, indent=2, default=str)

    weight_dtype = get_dtype_from_str(config.weight_dtype)

    logger.info("Loading models.")
    tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl(
        logger=logger,
        model_name_or_path=config.model,
        hf_variant=config.hf_variant,
        vae_model=config.vae_model,
        dtype=weight_dtype,
    )

    placeholder_tokens, placeholder_token_ids_1, placeholder_token_ids_2 = _initialize_placeholder_tokens(
        config=config,
        tokenizer_1=tokenizer_1,
        tokenizer_2=tokenizer_2,
        text_encoder_1=text_encoder_1,
        text_encoder_2=text_encoder_2,
        logger=logger,
    )
    logger.info(f"Initialized {len(placeholder_tokens)} placeholder tokens: {placeholder_tokens}.")

    # All parameters of the VAE, UNet, and text encoder are currently frozen. Just unfreeze the token embeddings in the
    # text encoders.
    text_encoder_1.text_model.embeddings.token_embedding.requires_grad_(True)
    text_encoder_2.text_model.embeddings.token_embedding.requires_grad_(True)

    if config.gradient_checkpointing:
        # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained.
        unet.enable_gradient_checkpointing()
        # unet must be in train() mode for gradient checkpointing to take effect.
        # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does
        # not change its forward behavior.
        unet.train()
        for te in [text_encoder_1, text_encoder_2]:
            # The text_encoder will be put in .train() mode later, so we don't need to worry about that here.
            # Note: There are some weird interactions between gradient checkpointing and requires_grad_() when
            # training a text_encoder LoRA. If this code ever gets copied elsewhere, make sure to take a look at how
            # this is handled in other training pipelines.
            te.gradient_checkpointing_enable()

    if config.xformers:
        import_xformers()

        unet.enable_xformers_memory_efficient_attention()
        vae.enable_xformers_memory_efficient_attention()

    # Prepare VAE output cache.
    vae_output_cache_dir_name = None
    if config.cache_vae_outputs:
        if config.data_loader.random_flip:
            raise ValueError("'cache_vae_outputs' cannot be True if 'random_flip' is True.")
        if not config.data_loader.center_crop:
            raise ValueError("'cache_vae_outputs' cannot be True if 'center_crop' is False.")

        # We use a temporary directory for the cache. The directory will automatically be cleaned up when
        # tmp_vae_output_cache_dir is destroyed.
        tmp_vae_output_cache_dir = tempfile.TemporaryDirectory()
        vae_output_cache_dir_name = tmp_vae_output_cache_dir.name
        if accelerator.is_local_main_process:
            # Only the local main process populates the cache.
            logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').")
            vae.to(accelerator.device, dtype=weight_dtype)
            data_loader = build_textual_inversion_sd_dataloader(
                config=config.data_loader,
                placeholder_token=config.placeholder_token,
                batch_size=config.train_batch_size,
                use_masks=config.use_masks,
                shuffle=False,
            )
            cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae)
        # Move the VAE back to the CPU, because it is not needed for training.
        vae.to("cpu")
        accelerator.wait_for_everyone()
    else:
        vae.to(accelerator.device, dtype=weight_dtype)

    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder_1.to(accelerator.device, dtype=weight_dtype)
    text_encoder_2.to(accelerator.device, dtype=weight_dtype)

    # Initialize the optimizer to only optimize the token embeddings.
    trainable_param_groups = [
        {"params": text_encoder_1.get_input_embeddings().parameters()},
        {"params": text_encoder_2.get_input_embeddings().parameters()},
    ]
    optimizer = initialize_optimizer(config.optimizer, trainable_param_groups)

    trainable_models = torch.nn.ModuleDict({"text_encoder_1": text_encoder_1, "text_encoder_2": text_encoder_2})

    data_loader = build_textual_inversion_sd_dataloader(
        config=config.data_loader,
        placeholder_token=config.placeholder_token,
        batch_size=config.train_batch_size,
        use_masks=config.use_masks,
        vae_output_cache_dir=vae_output_cache_dir_name,
    )

    log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler)

    # Explicit checks (not `assert`) so that config validation is not stripped when running with `python -O`.
    if (config.max_train_steps is None) == (config.max_train_epochs is None):
        raise ValueError("Exactly one of 'max_train_steps' or 'max_train_epochs' must be set.")
    if (config.save_every_n_steps is None) == (config.save_every_n_epochs is None):
        raise ValueError("Exactly one of 'save_every_n_steps' or 'save_every_n_epochs' must be set.")
    if (config.validate_every_n_steps is None) == (config.validate_every_n_epochs is None):
        raise ValueError("Exactly one of 'validate_every_n_steps' or 'validate_every_n_epochs' must be set.")

    # A "step" represents a single weight update operation (i.e. takes into account gradient accumulation steps).
    # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when
    # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached.
    num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps)
    if num_steps_per_epoch == 0:
        # Without this check, the epoch computation below fails with an opaque ZeroDivisionError.
        raise ValueError("The training data loader is empty. Check the dataset configuration.")
    # Use an explicit None check (not `or`) so that max_train_steps=0 is not treated as "unset".
    if config.max_train_steps is not None:
        num_train_steps = config.max_train_steps
    else:
        num_train_steps = config.max_train_epochs * num_steps_per_epoch
    num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch)

    # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps
    # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears
    # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process
    # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82),
    # so the scaling here simply reverses that behaviour.
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler(
        config.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=num_train_steps * accelerator.num_processes,
    )

    prepared_result: tuple[
        CLIPPreTrainedModel,
        CLIPPreTrainedModel,
        torch.optim.Optimizer,
        torch.utils.data.DataLoader,
        torch.optim.lr_scheduler.LRScheduler,
    ] = accelerator.prepare(text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler)
    text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result

    if accelerator.is_main_process:
        accelerator.init_trackers("textual_inversion_training")
        # Tensorboard uses markdown formatting, so we wrap the config json in a code block.
        accelerator.log({"configuration": f"```json\n{json.dumps(config.dict(), indent=2, default=str)}\n```\n"})

    checkpoint_tracker = CheckpointTracker(
        base_dir=ckpt_dir,
        prefix="checkpoint",
        extension=".safetensors",
        max_checkpoints=config.max_checkpoints,
    )

    # Train!
    total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num batches = {len(data_loader)}")
    logger.info(f"  Instantaneous batch size per device = {config.train_batch_size}")
    logger.info(f"  Gradient accumulation steps = {config.gradient_accumulation_steps}")
    logger.info(f"  Parallel processes = {accelerator.num_processes}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Total optimization steps = {num_train_steps}")
    logger.info(f"  Total epochs = {num_train_epochs}")

    global_step = 0
    first_epoch = 0
    completed_epochs = 0

    progress_bar = tqdm(
        range(global_step, num_train_steps),
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    # Keep original embeddings as reference.
    with torch.no_grad():
        orig_embeds_params_1 = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()
        orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()

    def save_checkpoint(num_completed_epochs: int, num_completed_steps: int):
        # Synchronize all processes; only the main process writes the checkpoint.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            _save_ti_embeddings(
                epoch=num_completed_epochs,
                step=num_completed_steps,
                text_encoder_1=text_encoder_1,
                text_encoder_2=text_encoder_2,
                placeholder_token_ids_1=placeholder_token_ids_1,
                placeholder_token_ids_2=placeholder_token_ids_2,
                accelerator=accelerator,
                logger=logger,
                checkpoint_tracker=checkpoint_tracker,
                callbacks=callbacks,
            )
        accelerator.wait_for_everyone()

    def validate(num_completed_epochs: int, num_completed_steps: int):
        # Synchronize all processes; only the main process generates validation images.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            generate_validation_images_sdxl(
                epoch=num_completed_epochs,
                step=num_completed_steps,
                out_dir=out_dir,
                accelerator=accelerator,
                vae=vae,
                text_encoder_1=text_encoder_1,
                text_encoder_2=text_encoder_2,
                tokenizer_1=tokenizer_1,
                tokenizer_2=tokenizer_2,
                noise_scheduler=noise_scheduler,
                unet=unet,
                config=config,
                logger=logger,
                callbacks=callbacks,
            )
        accelerator.wait_for_everyone()

    for epoch in range(first_epoch, num_train_epochs):
        text_encoder_1.train()
        text_encoder_2.train()

        train_loss = 0.0
        for data_batch_idx, data_batch in enumerate(data_loader):
            with accelerator.accumulate(trainable_models):
                loss = train_forward(
                    accelerator=accelerator,
                    data_batch=data_batch,
                    vae=vae,
                    noise_scheduler=noise_scheduler,
                    tokenizer_1=tokenizer_1,
                    tokenizer_2=tokenizer_2,
                    text_encoder_1=text_encoder_1,
                    text_encoder_2=text_encoder_2,
                    unet=unet,
                    weight_dtype=weight_dtype,
                    resolution=config.data_loader.resolution,
                    use_masks=config.use_masks,
                    prediction_type=config.prediction_type,
                    min_snr_gamma=config.min_snr_gamma,
                )

                # Gather the losses across all processes for logging (if we use distributed training).
                # TODO(ryand): Test that this works properly with distributed training.
                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()
                train_loss += avg_loss.item() / config.gradient_accumulation_steps

                # Backpropagate.
                accelerator.backward(loss)
                if accelerator.sync_gradients and config.max_grad_norm is not None:
                    # TODO(ryand): I copied this from another pipeline. Should probably just clip the trainable params.
                    params_to_clip = trainable_models.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

                # Make sure we don't update any embedding weights besides the newly-added token(s).
                # TODO(ryand): Should we only do this if accelerator.sync_gradients?
                restore_original_embeddings(
                    tokenizer=tokenizer_1,
                    placeholder_token_ids=placeholder_token_ids_1,
                    accelerator=accelerator,
                    text_encoder=text_encoder_1,
                    orig_embeds_params=orig_embeds_params_1,
                )
                restore_original_embeddings(
                    tokenizer=tokenizer_2,
                    placeholder_token_ids=placeholder_token_ids_2,
                    accelerator=accelerator,
                    text_encoder=text_encoder_2,
                    orig_embeds_params=orig_embeds_params_2,
                )

            # Checks if the accelerator has performed an optimization step behind the scenes.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                # The epoch only counts as completed once its last batch has been consumed.
                completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1
                log = {"train_loss": train_loss, "lr": lr_scheduler.get_last_lr()[0]}

                if config.optimizer.optimizer_type == "Prodigy":
                    # TODO(ryand): Test Prodigy logging.
                    log["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]

                accelerator.log(log, step=global_step)
                train_loss = 0.0

                # global_step represents the *number of completed steps* at this point.
                if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0:
                    save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

                if (
                    config.validate_every_n_steps is not None
                    and global_step % config.validate_every_n_steps == 0
                    and len(config.validation_prompts) > 0
                ):
                    validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
            }
            progress_bar.set_postfix(**logs)

            if global_step >= num_train_steps:
                break

        # Save a checkpoint every n epochs.
        if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0:
            save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

        # Generate validation images every n epochs.
        if (
            config.validate_every_n_epochs is not None
            and completed_epochs % config.validate_every_n_epochs == 0
            and len(config.validation_prompts) > 0
        ):
            validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step)

    # Release the tqdm display/resources before shutting down the trackers.
    progress_bar.close()
    accelerator.end_training()
================================================
FILE: src/invoke_training/sample_configs/_experimental/sd_dpo_lora_pickapic_1x24gb.yaml
================================================
# Training mode: Direct Preference Optimization LoRA Training
# Dataset: A small subset of the pickapic_v2 dataset.
# Base model: SD 1.5
# GPU: 1 x 24GB
#
# Training takes ~2 hours on a single RTX 4090.
type: SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA
seed: 1
base_output_dir: output/dpo
optimizer:
optimizer_type: AdamW
learning_rate: 1e-4
weight_decay: 1e-2
lr_warmup_steps: 200
lr_scheduler: cosine
data_loader:
type: IMAGE_PAIR_PREFERENCE_SD_DATA_LOADER
dataset:
type: HF_HUB_IMAGE_PAIR_PREFERENCE_DATASET
resolution: 512
# General
model: runwayml/stable-diffusion-v1-5
gradient_accumulation_steps: 2
weight_dtype: float16
mixed_precision: fp16
gradient_checkpointing: True
max_train_steps: 5000
save_every_n_epochs: 1
save_every_n_steps: null
max_checkpoints: 100
validation_prompts:
- A monk in an orange robe by a round window in a spaceship in dramatic lighting
- A galaxy-colored figurine is floating over the sea at sunset, photorealistic
- Concept art of a mythical sky alligator with wings, nature documentary
validate_every_n_epochs: 1
train_batch_size: 4
num_validation_images_per_prompt: 1
================================================
FILE: src/invoke_training/sample_configs/_experimental/sd_dpo_lora_refinement_pokemon_1x24gb.yaml
================================================
# Training mode: Direct Preference Optimization LoRA Training
# Base model: SD 1.5
# GPU: 1 x 24GB
type: SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA
seed: 1
base_output_dir: output/dpo
optimizer:
optimizer_type: AdamW
learning_rate: 1e-4
weight_decay: 1e-2
lr_warmup_steps: 500
lr_scheduler: cosine
data_loader:
type: IMAGE_PAIR_PREFERENCE_SD_DATA_LOADER
dataset:
type: IMAGE_PAIR_PREFERENCE_DATASET
dataset_dir: output/pokemon_pairs
resolution: 512
dataloader_num_workers: 4
# General
model: runwayml/stable-diffusion-v1-5
initial_lora: output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003
gradient_accumulation_steps: 2
weight_dtype: float16
mixed_precision: fp16
gradient_checkpointing: True
max_train_steps: 5000
save_every_n_epochs: 10
save_every_n_steps: null
max_checkpoints: 100
validation_prompts:
- A cute yoda pokemon creature.
- A cute astronaut pokemon creature.
validate_every_n_epochs: 10
train_batch_size: 4
num_validation_images_per_prompt: 2
================================================
FILE: src/invoke_training/sample_configs/flux_lora_1x40gb.yaml
================================================
# Training mode: LoRA
# Base model: Flux.1-dev
# Dataset: Bruce the Gnome
# GPU: 1 x 40GB
type: FLUX_LORA
seed: 1
base_output_dir: output/experiments/bruce_the_gnome/flux_lora
optimizer:
optimizer_type: AdamW
learning_rate: 1e-4
lr_warmup_steps: 1
lr_scheduler: constant
transformer_learning_rate: 4e-4
text_encoder_learning_rate: 4e-4
train_text_encoder: False
data_loader:
type: IMAGE_CAPTION_FLUX_DATA_LOADER
dataset:
type: IMAGE_CAPTION_JSONL_DATASET
# Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset.
jsonl_path: sample_data/bruce_the_gnome/data.jsonl
resolution: 768
aspect_ratio_buckets:
target_resolution: 768
start_dim: 384
end_dim: 1536
divisible_by: 128
caption_prefix: "bruce the gnome"
dataloader_num_workers: 4
# General
model: black-forest-labs/FLUX.1-dev
gradient_accumulation_steps: 1
weight_dtype: bfloat16
gradient_checkpointing: True
max_train_steps: 350
save_every_n_steps: 50
validate_every_n_steps: 50
max_checkpoints: 10
validation_prompts:
- A stuffed gnome at the beach with a pina colada in its hand.
- A stuffed gnome reading a book in a cozy library.
- A stuffed gnome sitting in a garden surrounded by colorful flowers and butterflies.
train_batch_size: 4
num_validation_images_per_prompt: 3
================================================
FILE: src/invoke_training/sample_configs/sd_lora_baroque_1x8gb.yaml
================================================
# Training mode: Finetuning with LoRA
# Base model: SD 1.5
# Dataset: https://huggingface.co/datasets/InvokeAI/nga-baroque
# GPU: 1 x 8GB
# Instructions:
# 1. Download the dataset from https://huggingface.co/datasets/InvokeAI/nga-baroque.
# 2. Update the `jsonl_path` field in the `data_loader` section to point to the `metadata.jsonl` file of the downloaded
# dataset.
# Notes:
# This config file has been optimized for the primary goal of achieving reasonable results *quickly* for demo purposes.
type: SD_LORA
seed: 1
base_output_dir: output/baroque/sd_lora
optimizer:
optimizer_type: Prodigy
learning_rate: 1.0
weight_decay: 0.01
use_bias_correction: True
safeguard_warmup: True
data_loader:
type: IMAGE_CAPTION_SD_DATA_LOADER
dataset:
type: IMAGE_CAPTION_JSONL_DATASET
# Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset.
jsonl_path: data/nga-baroque/metadata.jsonl
resolution: 512
aspect_ratio_buckets:
target_resolution: 512
start_dim: 256
end_dim: 768
divisible_by: 64
caption_prefix: "A baroque painting of"
dataloader_num_workers: 4
# General
model: runwayml/stable-diffusion-v1-5
gradient_accumulation_steps: 1
weight_dtype: bfloat16
gradient_checkpointing: True
max_train_epochs: 15
save_every_n_epochs: 1
validate_every_n_epochs: 1
max_checkpoints: 5
validation_prompts:
- A baroque painting of a woman carrying a basket of fruit.
- A baroque painting of a cute Yoda creature.
train_batch_size: 4
num_validation_images_per_prompt: 3
================================================
FILE: src/invoke_training/sample_configs/sd_textual_inversion_gnome_1x8gb.yaml
================================================
# Training mode: Textual Inversion
# Base model: SD v1
# GPU: 1 x 8GB
type: SD_TEXTUAL_INVERSION
seed: 1
base_output_dir: output/sd_ti_bruce_the_gnome
optimizer:
optimizer_type: AdamW
learning_rate: 4e-3
lr_warmup_steps: 200
lr_scheduler: cosine
data_loader:
type: TEXTUAL_INVERSION_SD_DATA_LOADER
dataset:
type: IMAGE_DIR_DATASET
dataset_dir: "sample_data/bruce_the_gnome"
keep_in_memory: True
caption_preset: object
resolution: 512
center_crop: True
random_flip: False
shuffle_caption_delimiter: null
aspect_ratio_buckets:
target_resolution: 512
start_dim: 256
end_dim: 768
divisible_by: 64
dataloader_num_workers: 4
# General
model: runwayml/stable-diffusion-v1-5
num_vectors: 4
placeholder_token: "bruce_the_gnome"
initializer_token: "gnome"
cache_vae_outputs: False
gradient_accumulation_steps: 1
weight_dtype: bfloat16
gradient_checkpointing: True
max_train_steps: 2000
save_every_n_steps: 200
validate_every_n_steps: 200
max_checkpoints: 20
validation_prompts:
- A photo of bruce_the_gnome at the beach
- A photo of bruce_the_gnome reading a book
train_batch_size: 1
num_validation_images_per_prompt: 3
================================================
FILE: src/invoke_training/sample_configs/sdxl_finetune_baroque_1x24gb.yaml
================================================
# Training mode: Full Finetuning
# Base model: SDXL
# Dataset: https://huggingface.co/datasets/InvokeAI/nga-baroque
# GPU: 1 x 24GB
# Instructions:
# 1. Download the dataset from https://huggingface.co/datasets/InvokeAI/nga-baroque.
# 2. Update the `jsonl_path` field in the `data_loader` section to point to the `metadata.jsonl` file of the downloaded
# dataset.
type: SDXL_FINETUNE
seed: 1
base_output_dir: output/baroque/sdxl_finetune
optimizer:
optimizer_type: AdamW
learning_rate: 5e-5
weight_decay: 1e-3
use_8bit: True
lr_scheduler: constant_with_warmup
lr_warmup_steps: 500
data_loader:
type: IMAGE_CAPTION_SD_DATA_LOADER
dataset:
type: IMAGE_CAPTION_JSONL_DATASET
# Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset.
jsonl_path: data/nga-baroque/metadata.jsonl
resolution: 1024
aspect_ratio_buckets:
target_resolution: 1024
start_dim: 512
end_dim: 1536
divisible_by: 128
caption_prefix: "A baroque style painting,"
# General
model: stabilityai/stable-diffusion-xl-base-1.0
save_checkpoint_format: trained_only_diffusers
# vae_model: madebyollin/sdxl-vae-fp16-fix
save_dtype: float16
gradient_accumulation_steps: 1
weight_dtype: bfloat16
gradient_checkpointing: True
cache_vae_outputs: True
cache_text_encoder_outputs: True
max_train_epochs: 50
save_every_n_epochs: 3
validate_every_n_epochs: 3
# We save a max of 1 checkpoint for demo purposes, because the checkpoints take up a lot of disk space.
max_checkpoints: 1
validation_prompts:
- A baroque style painting of a woman carrying a basket of fruit.
- A baroque style painting of a cute Yoda creature.
train_batch_size: 4
num_validation_images_per_prompt: 3
================================================
FILE: src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml
================================================
# Training mode: Full finetune
# Base model: SDXL
# Dataset: Robocats
# GPU: 1 x 24GB
type: SDXL_FINETUNE
seed: 1
base_output_dir: output/robocats/sdxl_finetune
optimizer:
optimizer_type: AdamW
learning_rate: 2e-5
use_8bit: True
lr_scheduler: constant_with_warmup
lr_warmup_steps: 200
data_loader:
type: IMAGE_CAPTION_SD_DATA_LOADER
dataset:
type: IMAGE_CAPTION_JSONL_DATASET
# Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset.
jsonl_path: /home/ryan/data/robocats/data.jsonl
resolution: 1024
aspect_ratio_buckets:
target_resolution: 1024
start_dim: 512
end_dim: 1536
divisible_by: 128
caption_prefix: "In the robocat style,"
# General
model: stabilityai/stable-diffusion-xl-base-1.0
save_checkpoint_format: trained_only_diffusers
# vae_model: madebyollin/sdxl-vae-fp16-fix
save_dtype: float16
gradient_accumulation_steps: 1
weight_dtype: bfloat16
gradient_checkpointing: True
cache_vae_outputs: True
cache_text_encoder_outputs: True
max_train_steps: 2000
validate_every_n_steps: 200
save_every_n_steps: 2000
# We save a max of 1 checkpoint for demo purposes, because the checkpoints take up a lot of disk space.
max_checkpoints: 1
validation_prompts:
- In the robocat style, a robotic lion in the jungle.
- In the robocat style, a hamburger and fries.
train_batch_size: 4
num_validation_images_per_prompt: 3
================================================
FILE: src/invoke_training/sample_configs/sdxl_lora_and_ti_gnome_1x24gb.yaml
================================================
# Training mode: Finetuning with LoRA and Textual Inversion
# Base model: SDXL 1.0
# GPU: 1 x 24GB
type: SDXL_LORA_AND_TEXTUAL_INVERSION
seed: 1
base_output_dir: output/sdxl_lora_and_ti_bruce_the_gnome
optimizer:
optimizer_type: AdamW
learning_rate: 2e-3
lr_warmup_steps: 200
lr_scheduler: constant
data_loader:
type: TEXTUAL_INVERSION_SD_DATA_LOADER
dataset:
type: IMAGE_DIR_DATASET
dataset_dir: "sample_data/bruce_the_gnome"
keep_in_memory: True
caption_preset: object
resolution: 1024
center_crop: True
random_flip: False
shuffle_caption_delimiter: null
dataloader_num_workers: 4
# General
model: stabilityai/stable-diffusion-xl-base-1.0
vae_model: madebyollin/sdxl-vae-fp16-fix
num_vectors: 2
placeholder_token: "bruce_the_gnome"
initializer_token: "gnome"
cache_vae_outputs: False
gradient_accumulation_steps: 1
weight_dtype: bfloat16
gradient_checkpointing: True
max_train_steps: 2000
save_every_n_steps: 200
validate_every_n_steps: 200
max_checkpoints: 50
validation_prompts:
- A photo of bruce_the_gnome at the beach
- A photo of bruce_the_gnome reading a book
train_batch_size: 1
num_validation_images_per_prompt: 3
================================================
FILE: src/invoke_training/sample_configs/sdxl_lora_baroque_1x24gb.yaml
================================================
# Training mode: Finetuning with LoRA
# Base model: SDXL 1.0
# Dataset: https://huggingface.co/datasets/InvokeAI/nga-baroque
# GPU: 1 x 24GB
# Instructions:
# 1. Download the dataset from https://huggingface.co/datasets/InvokeAI/nga-baroque.
# 2. Update the `jsonl_path` field in the `data_loader` section to point to the `metadata.jsonl` file of the downloaded
# dataset.
# Notes:
# This config file has been optimized for the primary goal of achieving reasonable results *quickly* for demo
# purposes.
type: SDXL_LORA
seed: 1
base_output_dir: output/baroque/sdxl_lora
optimizer:
optimizer_type: AdamW
learning_rate: 1e-3
data_loader:
type: IMAGE_CAPTION_SD_DATA_LOADER
dataset:
type: IMAGE_CAPTION_JSONL_DATASET
# Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset.
jsonl_path: data/nga-baroque/metadata_masks.jsonl
resolution: 1024
aspect_ratio_buckets:
target_resolution: 1024
start_dim: 512
end_dim: 1536
divisible_by: 128
caption_prefix: "A baroque painting of"
# General
model: stabilityai/stable-diffusion-xl-base-1.0
# vae_model: madebyollin/sdxl-vae-fp16-fix
gradient_accumulation_steps: 1
weight_dtype: bfloat16
gradient_checkpointing: True
cache_vae_outputs: True
max_train_epochs: 16
save_every_n_epochs: 2
validate_every_n_epochs: 2
use_masks: False
max_checkpoints: 5
validation_prompts:
- A baroque painting of a woman carrying a basket of fruit.
- A baroque painting of a cute Yoda creature.
train_batch_size: 4
num_validation_images_per_prompt: 3
================================================
FILE: src/invoke_training/sample_configs/sdxl_lora_baroque_1x8gb.yaml
================================================
# Training mode: Finetuning with LoRA
# Base model: SDXL 1.0
# Dataset: https://huggingface.co/datasets/InvokeAI/nga-baroque
# GPU: 1 x 8GB
# Instructions:
# 1. Download the dataset from https://huggingface.co/datasets/InvokeAI/nga-baroque.
# 2. Update the `jsonl_path` field in the `data_loader` section to point to the `metadata.jsonl` file of the downloaded
# dataset.
# Notes:
# This config file has been optimized for 2 primary goals:
# - Minimize VRAM usage so that an SDXL model can be trained with only 8GB of VRAM.
# - Achieve reasonable results *quickly* for demo purposes.
type: SDXL_LORA
seed: 1
base_output_dir: output/baroque/sdxl_lora
optimizer:
optimizer_type: Prodigy
learning_rate: 1.0
weight_decay: 0.01
use_bias_correction: True
safeguard_warmup: True
data_loader:
type: IMAGE_CAPTION_SD_DATA_LOADER
dataset:
type: IMAGE_CAPTION_JSONL_DATASET
# Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset.
jsonl_path: data/nga-baroque/metadata.jsonl
# TODO: More optimizations are needed to train at full 1024x1024 resolution with 8GB VRAM.
resolution: 512
# aspect_ratio_buckets:
# target_resolution: 1024
# start_dim: 512
# end_dim: 1536
# divisible_by: 128
caption_prefix: "A baroque painting of"
# General
model: stabilityai/stable-diffusion-xl-base-1.0
vae_model: madebyollin/sdxl-vae-fp16-fix
train_text_encoder: False
cache_text_encoder_outputs: True
enable_cpu_offload_during_validation: True
gradient_accumulation_steps: 4
weight_dtype: bfloat16
gradient_checkpointing: True
max_train_epochs: 6
save_every_n_epochs: 1
validate_every_n_epochs: 1
max_checkpoints: 5
validation_prompts:
- A baroque painting of a woman carrying a basket of fruit.
- A baroque painting of a cute Yoda creature.
train_batch_size: 1
num_validation_images_per_prompt: 3
================================================
FILE: src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml
================================================
# Training mode: LoRA with masks
# Base model: SDXL 1.0
# Dataset: Bruce the Gnome
# GPU: 1 x 24GB
type: SDXL_LORA
seed: 1
base_output_dir: output/bruce/sdxl_lora_masks
optimizer:
optimizer_type: AdamW
learning_rate: 7e-5
lr_scheduler: constant_with_warmup
lr_warmup_steps: 50
data_loader:
type: IMAGE_CAPTION_SD_DATA_LOADER
dataset:
type: IMAGE_CAPTION_JSONL_DATASET
jsonl_path: sample_data/bruce_the_gnome/data_masks.jsonl
resolution: 1024
aspect_ratio_buckets:
target_resolution: 1024
start_dim: 512
end_dim: 1536
divisible_by: 128
# General
model: stabilityai/stable-diffusion-xl-base-1.0
# vae_model: madebyollin/sdxl-vae-fp16-fix
gradient_accumulation_steps: 1
weight_dtype: bfloat16
gradient_checkpointing: True
cache_vae_outputs: True
max_train_steps: 500
save_every_n_steps: 50
validate_every_n_steps: 50
use_masks: True
max_checkpoints: 5
validation_prompts:
- A stuffed gnome at the beach with a pina colada in its hand.
- A stuffed gnome reading a book in a cozy library.
train_batch_size: 4
num_validation_images_per_prompt: 3
================================================
FILE: src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml
================================================
# Training mode: Textual Inversion
# Base model: SDXL
# GPU: 1 x 24GB
type: SDXL_TEXTUAL_INVERSION
seed: 1
base_output_dir: output/bruce/sdxl_ti
optimizer:
optimizer_type: AdamW
learning_rate: 2e-3
lr_warmup_steps: 200
lr_scheduler: cosine
data_loader:
type: TEXTUAL_INVERSION_SD_DATA_LOADER
dataset:
type: IMAGE_DIR_DATASET
dataset_dir: "sample_data/bruce_the_gnome"
keep_in_memory: True
caption_preset: object
resolution: 1024
center_crop: True
random_flip: False
shuffle_caption_delimiter: null
dataloader_num_workers: 4
# General
model: stabilityai/stable-diffusion-xl-base-1.0
vae_model: madebyollin/sdxl-vae-fp16-fix
num_vectors: 4
placeholder_token: "bruce_the_gnome"
initializer_token: "gnome"
cache_vae_outputs: False
gradient_accumulation_steps: 1
weight_dtype: bfloat16
gradient_checkpointing: True
max_train_steps: 2000
save_every_n_steps: 200
validate_every_n_steps: 200
max_checkpoints: 20
validation_prompts:
- A photo of bruce_the_gnome at the beach
- A photo of bruce_the_gnome reading a book
train_batch_size: 1
num_validation_images_per_prompt: 3
================================================
FILE: src/invoke_training/sample_configs/sdxl_textual_inversion_masks_gnome_1x24gb.yaml
================================================
# Training mode: Textual Inversion with Masks
# Base model: SDXL
# GPU: 1 x 24GB
type: SDXL_TEXTUAL_INVERSION
seed: 1
base_output_dir: output/bruce/sdxl_ti_masks
optimizer:
optimizer_type: AdamW
learning_rate: 5e-4
lr_scheduler: constant_with_warmup
lr_warmup_steps: 50
data_loader:
type: TEXTUAL_INVERSION_SD_DATA_LOADER
dataset:
type: IMAGE_CAPTION_JSONL_DATASET
jsonl_path: sample_data/bruce_the_gnome/data_masks.jsonl
keep_in_memory: True
caption_preset: object
resolution: 1024
center_crop: True
random_flip: False
shuffle_caption_delimiter: null
# General
model: stabilityai/stable-diffusion-xl-base-1.0
num_vectors: 16
placeholder_token: "bruce_the_gnome"
initializer_token: "gnome"
cache_vae_outputs: False
gradient_accumulation_steps: 1
weight_dtype: bfloat16
gradient_checkpointing: True
max_train_steps: 500
save_every_n_steps: 50
validate_every_n_steps: 50
use_masks: True
max_checkpoints: 10
validation_prompts:
- A photo of bruce_the_gnome at the beach with a pina colada in its hand.
- A photo of bruce_the_gnome reading a book in a cozy library.
train_batch_size: 4
num_validation_images_per_prompt: 3
================================================
FILE: src/invoke_training/scripts/__init__.py
================================================
================================================
FILE: src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py
================================================
import argparse
import json
from pathlib import Path
import torch
import torch.utils.data
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from invoke_training.scripts.utils.image_dir_dataset import ImageDirDataset, list_collate_fn
def select_device_and_dtype(force_cpu: bool = False) -> tuple[torch.device, torch.dtype]:
    """Choose the device and dtype used to run the captioning model.

    Args:
        force_cpu (bool): When True, always use the CPU, even if CUDA is available.

    Returns:
        tuple[torch.device, torch.dtype]: (cuda, float16) when CUDA is available and not overridden, otherwise
            (cpu, float32).
    """
    # `force_cpu` is tested first so that CUDA availability is not even queried when the CPU is forced.
    if not force_cpu and torch.cuda.is_available():
        return torch.device("cuda"), torch.float16
    return torch.device("cpu"), torch.float32
def process_images(images: list[Image.Image], prompt: str, moondream, tokenizer) -> list[str]:
    """Ask the moondream model the same `prompt` about every image in a batch.

    Args:
        images (list[Image.Image]): The batch of images to caption.
        prompt (str): The question/instruction sent along with each image.
        moondream: The loaded moondream model; its `batch_answer` method is used.
        tokenizer: The tokenizer passed through to `batch_answer`.

    Returns:
        list[str]: Whatever `moondream.batch_answer` returns for the batch.
    """
    repeated_prompts = [prompt for _ in images]
    return moondream.batch_answer(images=images, prompts=repeated_prompts, tokenizer=tokenizer)
def main(
    prompt: str,
    use_cpu: bool,
    batch_size: int,
    output_path: str,
    dataset: torch.utils.data.Dataset,
):
    """Caption every image in `dataset` with moondream2 and write the results to a .jsonl file.

    Each output line is a JSON object of the form {"image": <image path>, "text": <caption>}.

    Args:
        prompt (str): The question/instruction sent to the model for every image.
        use_cpu (bool): Force CPU inference even if CUDA is available.
        batch_size (int): Number of images captioned per model call.
        output_path (str): Destination .jsonl path. Must not already exist.
        dataset (torch.utils.data.Dataset): Dataset whose examples provide "image" and "image_path" entries (batched
            with `list_collate_fn`).

    Raises:
        FileExistsError: If `output_path` already exists, either before captioning starts or at write time.
    """
    device, dtype = select_device_and_dtype(use_cpu)
    print(f"Using device: {device}")
    print(f"Using dtype: {dtype}")

    # Check that the output file does not already exist before spending time generating captions.
    out_path = Path(output_path)
    if out_path.exists():
        raise FileExistsError(f"Output file already exists: {out_path}")

    # Load the model.
    model_id = "vikhyatk/moondream2"
    model_revision = "2024-04-02"
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision)
    # TODO(ryand): Warn about security implications of trust_remote_code=True.
    moondream_model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained(
        model_id, trust_remote_code=True, revision=model_revision
    ).to(device=device, dtype=dtype)
    moondream_model.eval()

    data_loader = torch.utils.data.DataLoader(
        dataset, collate_fn=list_collate_fn, batch_size=batch_size, drop_last=False
    )

    results = []
    # Pure inference: disable autograd so no activation graph is kept alive between batches.
    with torch.no_grad():
        for image_batch in tqdm(data_loader):
            image_paths = image_batch["image_path"]
            answers = process_images(image_batch["image"], prompt, moondream_model, tokenizer)
            for image_path, answer in zip(image_paths, answers, strict=True):
                results.append({"image": image_path, "text": answer})

    # Check that the output file does not exist immediately before writing to it. Opening with mode "x" also makes the
    # open itself fail if the file is created between this check and the write.
    if out_path.exists():
        raise FileExistsError(f"Output file already exists: {out_path}")
    with open(out_path, "x") as outfile:
        for entry in results:
            json.dump(entry, outfile)
            outfile.write("\n")

    # Report the actual destination (previously this always claimed "output.jsonl", ignoring --output).
    print(f"Output saved to {out_path}.")
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="Run the moondream captioning model on a directory of images.")
    arg_parser.add_argument("--dir", type=str, required=True, help="Directory containing images.")
    arg_parser.add_argument(
        "--prompt",
        type=str,
        default="Describe this image in 20 words or less.",
        help="(Optional) Prompt for the model.",
    )
    arg_parser.add_argument(
        "--cpu",
        action="store_true",
        default=False,
        help="Force use of CPU instead of GPU. If not set, a GPU will be used if available.",
    )
    arg_parser.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size for processing images. To maximize speed, set this to the largest value that fits in GPU memory.",
    )
    arg_parser.add_argument(
        "--output",
        type=str,
        default="output.jsonl",
        help="(Optional) Path to the output file. Default is 'output.jsonl'.",
    )
    cli_args = arg_parser.parse_args()

    # Build the dataset up front so the image count is reported before the (slow) model load inside main().
    image_dataset = ImageDirDataset(cli_args.dir)
    print(f"Found {len(image_dataset)} images in '{cli_args.dir}'.")

    main(
        prompt=cli_args.prompt,
        use_cpu=cli_args.cpu,
        batch_size=cli_args.batch_size,
        output_path=cli_args.output,
        dataset=image_dataset,
    )
================================================
FILE: src/invoke_training/scripts/_experimental/masks/clipseg.py
================================================
import torch
from PIL import Image
from transformers import AutoProcessor, CLIPSegForImageSegmentation, CLIPSegProcessor
def load_clipseg_model() -> tuple[CLIPSegProcessor, CLIPSegForImageSegmentation]:
    """Load the pretrained CLIPSeg processor and segmentation model.

    Returns:
        tuple[CLIPSegProcessor, CLIPSegForImageSegmentation]: The (processor, model) pair for the
            "CIDAS/clipseg-rd64-refined" checkpoint. Device placement is left to the caller (`run_clipseg` moves the
            model to its target device).
    """
    checkpoint = "CIDAS/clipseg-rd64-refined"
    processor = AutoProcessor.from_pretrained(checkpoint)
    segmentation_model = CLIPSegForImageSegmentation.from_pretrained(checkpoint)
    return processor, segmentation_model
def run_clipseg(
    images: list[Image.Image],
    prompt: str,
    clipseg_processor,
    clipseg_model,
    clipseg_temp: float,
    device: torch.device,
) -> list[Image.Image]:
    """Run CLIPSeg on a list of images and return one greyscale mask per image.

    Args:
        images (list[Image.Image]): The images to segment.
        prompt (str): Text description of the region to segment. The same prompt is used for every image.
        clipseg_processor: The CLIPSeg processor (e.g. from `load_clipseg_model()`).
        clipseg_model: The CLIPSeg model (e.g. from `load_clipseg_model()`). It is moved to `device`.
        clipseg_temp (float): Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother'
            and include more of the background. Recommended range: 0.5 to 1.0.
        device (torch.device): Device to run inference on.

    Returns:
        list[Image.Image]: Mode "L" masks, each resized to the size of its input image. Each mask is scaled
            independently so that its brightest pixel is 255.
    """
    orig_image_sizes = [img.size for img in images]
    prompts = [prompt] * len(images)

    # TODO(ryand): Should we run the same image with and without the prompt to normalize for any bias in the model?
    inputs = clipseg_processor(text=prompts, images=images, padding=True, return_tensors="pt")

    # Move inputs and clipseg_model to the correct device.
    inputs = {k: v.to(device=device) for k, v in inputs.items()}
    clipseg_model = clipseg_model.to(device=device)

    # Run inference without autograd. Besides saving memory, this is required for correctness when the caller has not
    # entered a no_grad context: `.numpy()` below raises on tensors that require grad.
    with torch.no_grad():
        logits = clipseg_model(**inputs).logits

    if logits.ndim == 2:
        # The model squeezes the batch dimension if it's 1, so we need to unsqueeze it.
        logits = logits.unsqueeze(0)

    probs = torch.sigmoid(logits / clipseg_temp)
    # Normalize each mask to 0-255. Note that each mask is normalized independently.
    probs = 255 * probs / probs.amax(dim=(1, 2), keepdim=True)

    masks: list[Image.Image] = []
    for prob, orig_size in zip(probs, orig_image_sizes, strict=True):
        # Cast to float32 first: PIL cannot build an image from float16/bfloat16 data (e.g. a half-precision model).
        mask = Image.fromarray(prob.float().cpu().numpy()).convert("L")
        mask = mask.resize(orig_size)
        masks.append(mask)
    return masks
def select_device() -> torch.device:
    """Return the CUDA device when one is available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
================================================
FILE: src/invoke_training/scripts/_experimental/masks/generate_masks.py
================================================
import argparse
from pathlib import Path
import torch
import torch.utils.data
from tqdm import tqdm
from invoke_training.scripts._experimental.masks.clipseg import load_clipseg_model, run_clipseg, select_device
from invoke_training.scripts.utils.image_dir_dataset import ImageDirDataset, list_collate_fn
@torch.no_grad()
def generate_masks(image_dir: str, prompt: str, clipseg_temp: float, batch_size: int):
    """Create a CLIPSeg mask for every image in a directory.

    Each mask is written as a PNG into a `masks/` subdirectory next to its image and is named after the image's stem
    (e.g. `<dir>/cat.jpg` -> `<dir>/masks/cat.png`).

    Args:
        image_dir (str): The directory containing images.
        prompt (str): A short description of the thing you want to mask. E.g. 'a cat'.
        clipseg_temp (float): Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother'
            and include more of the background. Recommended range: 0.5 to 1.0.
        batch_size (int): Batch size to use when processing images. Larger batch sizes may be faster but require more
            memory.
    """
    device = select_device()
    clipseg_processor, clipseg_model = load_clipseg_model()

    image_dataset = ImageDirDataset(image_dir)
    print(f"Found {len(image_dataset)} images in '{image_dir}'.")
    loader = torch.utils.data.DataLoader(
        image_dataset, collate_fn=list_collate_fn, batch_size=batch_size, drop_last=False
    )

    for batch in tqdm(loader):
        batch_masks = run_clipseg(
            images=batch["image"],
            prompt=prompt,
            clipseg_processor=clipseg_processor,
            clipseg_model=clipseg_model,
            clipseg_temp=clipseg_temp,
            device=device,
        )
        for raw_path, mask in zip(batch["image_path"], batch_masks, strict=True):
            src_path = Path(raw_path)
            mask_path = src_path.parent / "masks" / f"{src_path.stem}.png"
            mask_path.parent.mkdir(exist_ok=True, parents=True)
            mask.save(mask_path)
            print(f"Saved mask to: {mask_path}")
def main():
    """Command-line entry point: parse arguments and generate masks for a directory of images."""
    cli = argparse.ArgumentParser(description="Generate masks for a directory of images.")
    cli.add_argument("--dir", type=str, required=True, help="Directory containing images.")
    cli.add_argument(
        "--prompt",
        required=True,
        type=str,
        help="A short description of the thing you want to mask. E.g. 'a cat'.",
    )
    cli.add_argument(
        "--clipseg-temp",
        type=float,
        default=1.0,
        help="Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother' and include more "
        "of the background. Recommended range: 0.5 to 1.0.",
    )
    cli.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size to use when processing images. Larger batch sizes may be faster but require more memory.",
    )
    opts = cli.parse_args()

    generate_masks(
        image_dir=opts.dir,
        prompt=opts.prompt,
        clipseg_temp=opts.clipseg_temp,
        batch_size=opts.batch_size,
    )


if __name__ == "__main__":
    main()
================================================
FILE: src/invoke_training/scripts/_experimental/masks/generate_masks_for_jsonl_dataset.py
================================================
import argparse
from pathlib import Path
import torch
import torch.utils.data
from tqdm import tqdm
from invoke_training._shared.data.datasets.image_caption_jsonl_dataset import (
MASK_COLUMN_DEFAULT,
ImageCaptionJsonlDataset,
)
from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl
from invoke_training.scripts._experimental.masks.clipseg import load_clipseg_model, run_clipseg, select_device
def collate_fn(examples):
    """Batch examples into per-key lists ("id", "image") instead of stacking them into tensors.

    Any other keys present on the examples are dropped.
    """
    return {key: [example[key] for example in examples] for key in ("id", "image")}
def validate_out_json_path(out_json_path: str | Path):
out_json_path = Path(out_json_path)
if out_json_path.exists():
raise FileExistsError(f"Output jsonl file '{out_json_path}' already exists.")
if not out_json_path.suffix == ".jsonl":
raise ValueError(f"Output jsonl file '{out_json_path}' must have a .jsonl extension.")
@torch.no_grad()
def generate_masks(
    in_jsonl_path: str,
    out_jsonl_path: str,
    image_column: str,
    caption_column: str,
    prompt: str,
    clipseg_temp: float,
    batch_size: int,
):
    """Generate CLIPSeg masks for a .jsonl dataset and write a copy of the dataset that references them.

    Masks are saved as PNGs in a new `masks/` directory next to `out_jsonl_path`, named after each image's stem.
    If two images share a stem, a numeric suffix is added so that no mask is overwritten. Every row of the output
    .jsonl gets a `MASK_COLUMN_DEFAULT` entry pointing at its mask. The entry is the resolved absolute path if the
    row's image path is absolute; otherwise it is relative to the output .jsonl file's directory.

    Args:
        in_jsonl_path (str): Path to the input dataset .jsonl file.
        out_jsonl_path (str): Path of the .jsonl file to write. Must not exist and must end in `.jsonl`.
        image_column (str): Name of the column holding image paths.
        caption_column (str): Name of the column holding captions.
        prompt (str): A short description of the thing to mask. E.g. 'a cat'.
        clipseg_temp (float): Temperature applied to the CLIPSeg logits (see `run_clipseg`).
        batch_size (int): Number of images processed per CLIPSeg call.

    Raises:
        FileExistsError: If `out_jsonl_path` or the sibling `masks/` directory already exists.
        ValueError: If `out_jsonl_path` does not have a `.jsonl` extension.
    """
    # Load the .jsonl dataset.
    dataset = ImageCaptionJsonlDataset(
        jsonl_path=in_jsonl_path, image_column=image_column, caption_column=caption_column
    )
    print(f"Loaded dataset from '{in_jsonl_path}' with {len(dataset)} images.")
    data_loader = torch.utils.data.DataLoader(dataset, collate_fn=collate_fn, batch_size=batch_size, drop_last=False)

    # The raw rows are also needed so that the mask column can be added while keeping every other field intact.
    jsonl_data = load_jsonl(in_jsonl_path)

    # Prepare output locations.
    out_jsonl_path = Path(out_jsonl_path)
    validate_out_json_path(out_jsonl_path)
    out_masks_dir = out_jsonl_path.parent / "masks"
    # exist_ok=False: never mix the masks from this run with files left over from a previous run.
    out_masks_dir.mkdir(exist_ok=False, parents=True)

    clipseg_processor, clipseg_model = load_clipseg_model()
    device = select_device()

    # Mask file names written so far in this run. Images in different directories can share a stem, and naming masks
    # by stem alone would let a later mask silently overwrite an earlier one, leaving two rows pointing at one file.
    used_mask_names: set[str] = set()

    # Process each image.
    for batch in tqdm(data_loader):
        masks = run_clipseg(
            images=batch["image"],
            prompt=prompt,
            clipseg_processor=clipseg_processor,
            clipseg_model=clipseg_model,
            clipseg_temp=clipseg_temp,
            device=device,
        )
        for example_id, mask in zip(batch["id"], masks, strict=True):
            row_index = int(example_id)
            orig_image_path = Path(jsonl_data[row_index][image_column])

            # Pick a unique file name within masks/: '<stem>.png', then '<stem>_1.png', '<stem>_2.png', ...
            mask_name = f"{orig_image_path.stem}.png"
            suffix = 1
            while mask_name in used_mask_names:
                mask_name = f"{orig_image_path.stem}_{suffix}.png"
                suffix += 1
            used_mask_names.add(mask_name)

            out_mask_path: Path = out_masks_dir / mask_name
            mask.save(out_mask_path)
            print(f"Saved mask to: {out_mask_path}")

            # Infer whether the mask path should be relative or absolute based on the image path.
            if orig_image_path.is_absolute():
                jsonl_data[row_index][MASK_COLUMN_DEFAULT] = str(out_mask_path.resolve())
            else:
                jsonl_data[row_index][MASK_COLUMN_DEFAULT] = str(out_mask_path.relative_to(out_jsonl_path.parent))

    # Save the modified jsonl data. Re-validate in case the output file appeared while the masks were generated.
    validate_out_json_path(out_jsonl_path)
    save_jsonl(jsonl_data, out_jsonl_path)
    print(f"Saved modified jsonl data to: {out_jsonl_path}")
def main():
    """Command-line entry point: parse arguments and generate masks for a .jsonl dataset."""
    cli = argparse.ArgumentParser(description="Generate masks for a jsonl dataset.")
    cli.add_argument("--in-jsonl", type=str, required=True, help="Path to the dataset .jsonl file.")
    cli.add_argument(
        "--out-jsonl",
        type=str,
        required=True,
        help="Path to save the modified .jsonl file to. A masks/ directory will be created in the same directory as the "
        ".jsonl file to store the masks. The choice of whether to use relative or absolute paths for the masks is "
        "inferred from the image paths.",
    )
    cli.add_argument(
        "--image-column",
        type=str,
        default="image",
        help="The name of the column containing image paths in the input .jsonl file.",
    )
    cli.add_argument(
        "--caption-column",
        type=str,
        default="text",
        help="The name of the column containing captions in the input .jsonl file.",
    )
    cli.add_argument(
        "--prompt",
        required=True,
        type=str,
        help="A short description of the thing you want to mask. E.g. 'a cat'.",
    )
    cli.add_argument(
        "--clipseg-temp",
        type=float,
        default=1.0,
        help="Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother' and include more "
        "of the background. Recommended range: 0.5 to 1.0.",
    )
    cli.add_argument(
        "--batch-size",
        type=int,
        default=4,
        help="Batch size to use when processing images. Larger batch sizes may be faster but require more memory.",
    )
    opts = cli.parse_args()

    generate_masks(
        in_jsonl_path=opts.in_jsonl,
        out_jsonl_path=opts.out_jsonl,
        image_column=opts.image_column,
        caption_column=opts.caption_column,
        prompt=opts.prompt,
        clipseg_temp=opts.clipseg_temp,
        batch_size=opts.batch_size,
    )


if __name__ == "__main__":
    main()
================================================
FILE: src/invoke_training/scripts/_experimental/rank_images.py
================================================
import argparse
import os
import time
from pathlib import Path
from typing import Literal
import gradio as gr
import yaml
from pydantic import TypeAdapter
from invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset
from invoke_training.config.pipeline_config import PipelineConfig
def parse_args():
    """Parse the command-line arguments for the image-pair ranking UI.

    Returns:
        argparse.Namespace: Namespace with a `cfg_file` attribute of type Path.
    """
    cli = argparse.ArgumentParser(description="Choose preferences from image pairs.")
    cli.add_argument(
        "-c",
        "--cfg-file",
        type=Path,
        required=True,
        help="Path to the YAML training config file. The internal dataset config will be used.",
    )
    return cli.parse_args()
def clip(val, min_val, max_val):
    """Clamp `val` to [min_val, max_val]. If min_val > max_val, the result is min_val."""
    capped_above = min(val, max_val)
    return max(capped_above, min_val)
def main():
    """Launch a Gradio UI for labelling which image of each pair is preferred.

    The dataset comes from the `data_loader.dataset` section of the training config passed via `--cfg-file`. That
    dataset must be an IMAGE_PAIR_PREFERENCE_DATASET. Preferences are edited in memory. They are only written when
    "Save Metadata" is pressed, as a new `metadata-<timestamp>.jsonl` file in the dataset directory.

    Raises:
        ValueError: If the config's dataset is not an IMAGE_PAIR_PREFERENCE_DATASET, or if it contains no image pairs.
    """
    args = parse_args()

    # Load YAML config file.
    with open(args.cfg_file, "r") as f:
        cfg = yaml.safe_load(f)

    pipeline_adapter: TypeAdapter[PipelineConfig] = TypeAdapter(PipelineConfig)
    train_config = pipeline_adapter.validate_python(cfg)

    dataset_config = train_config.data_loader.dataset
    # Explicit check rather than `assert`, so that the validation is not stripped when running with `python -O`.
    if dataset_config.type != "IMAGE_PAIR_PREFERENCE_DATASET":
        raise ValueError(
            "The config's data_loader.dataset must be of type 'IMAGE_PAIR_PREFERENCE_DATASET', "
            f"but got '{dataset_config.type}'."
        )

    metadata = ImagePairPreferenceDataset.load_metadata(dataset_config.dataset_dir)
    if len(metadata) == 0:
        raise ValueError(f"No image pairs found in '{dataset_config.dataset_dir}'.")

    print(f"Launching UI to rank image pairs in '{dataset_config.dataset_dir}'.")

    def is_valid_index(index) -> bool:
        # The UI starts at index -1 and the number field can be cleared (None) or typed into freely.
        return index is not None and 0 <= index < len(metadata)

    def get_img_path(index: int, image_id: Literal["image_0", "image_1"]):
        return os.path.join(dataset_config.dataset_dir, metadata[index][image_id])

    def get_state(index: int):
        # Values for the outputs [index, img_0, img_1, prefer_0, prefer_1, caption].
        img_0 = get_img_path(index, "image_0")
        img_1 = get_img_path(index, "image_1")
        prefer_0 = metadata[index]["prefer_0"]
        prefer_1 = metadata[index]["prefer_1"]
        caption = metadata[index]["prompt"]
        return [index, img_0, img_1, prefer_0, prefer_1, caption]

    def go_to_index(index):
        # A cleared number field yields None; treat it as the first pair.
        if index is None:
            index = 0
        new_index = clip(int(index), 0, len(metadata) - 1)
        return get_state(new_index)

    def mark_preference(index, prefer_first: bool):
        # Only record a preference for an in-range pair. Before the user navigates, `index` is -1, and negative
        # indexing would otherwise silently mark the *last* pair. In that case, just jump to a valid pair.
        if not is_valid_index(index):
            return go_to_index(index)
        metadata[int(index)]["prefer_0"] = prefer_first
        metadata[int(index)]["prefer_1"] = not prefer_first
        # Step to next example.
        return go_to_index(int(index) + 1)

    def mark_prefer_0(index):
        return mark_preference(index, prefer_first=True)

    def mark_prefer_1(index):
        return mark_preference(index, prefer_first=False)

    def save_metadata():
        timestamp = str(time.time()).replace(".", "_")
        metadata_file = f"metadata-{timestamp}.jsonl"
        metadata_path = ImagePairPreferenceDataset.save_metadata(
            metadata=metadata, dataset_dir=dataset_config.dataset_dir, metadata_file=metadata_file
        )
        print(f"Saved metadata to '{metadata_path}'.")

    with gr.Blocks() as demo:
        index = gr.Number(value=-1, precision=0)
        with gr.Row():
            img_0 = gr.Image(type="filepath", label="Image 0", interactive=False)
            img_1 = gr.Image(type="filepath", label="Image 1", interactive=False)
        caption = gr.Textbox(interactive=False, show_label=False)
        with gr.Row():
            prefer_0 = gr.Checkbox(label="Prefer 0", interactive=False)
            prefer_1 = gr.Checkbox(label="Prefer 1", interactive=False)
        with gr.Row():
            mark_prefer_0_button = gr.Button("Prefer 0")
            mark_prefer_1_button = gr.Button("Prefer 1")
        save_metadata_button = gr.Button("Save Metadata")

        index.change(go_to_index, inputs=[index], outputs=[index, img_0, img_1, prefer_0, prefer_1, caption])
        mark_prefer_0_button.click(
            mark_prefer_0, inputs=[index], outputs=[index, img_0, img_1, prefer_0, prefer_1, caption]
        )
        mark_prefer_1_button.click(
            mark_prefer_1, inputs=[index], outputs=[index, img_0, img_1, prefer_0, prefer_1, caption]
        )
        save_metadata_button.click(save_metadata)

    demo.launch()


if __name__ == "__main__":
    main()
================================================
FILE: src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py
================================================
import argparse
from pathlib import Path
import torch
from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (
convert_sd_peft_checkpoint_to_kohya_state_dict,
)
def parse_args():
    """Parse the command-line arguments for the PEFT -> kohya LoRA conversion.

    Returns:
        argparse.Namespace: Namespace with `src_ckpt_dir`, `dst_ckpt_file`, and `dtype` string attributes.
    """
    cli = argparse.ArgumentParser(description="Convert a Stable Diffusion LoRA checkpoint in PEFT format to kohya format.")
    cli.add_argument("--src-ckpt-dir", type=str, required=True, help="Path to the source checkpoint directory.")
    cli.add_argument("--dst-ckpt-file", type=str, required=True, help="Path to the destination Kohya checkpoint file.")
    cli.add_argument(
        "--dtype",
        type=str,
        default="fp16",
        help="The precision to save the kohya state dict in. One of ['fp16', 'fp32'].",
    )
    return cli.parse_args()
def main():
    """Convert a PEFT-format SD LoRA checkpoint directory into a single kohya-format checkpoint file.

    Raises:
        ValueError: If `--dtype` is not one of 'fp16' or 'fp32'.
    """
    args = parse_args()
    src_dir = Path(args.src_ckpt_dir)
    dst_file = Path(args.dst_ckpt_file)

    dtype = {"fp32": torch.float32, "fp16": torch.float16}.get(args.dtype)
    if dtype is None:
        raise ValueError(f"Unsupported --dtype = '{args.dtype}'.")

    convert_sd_peft_checkpoint_to_kohya_state_dict(in_checkpoint_dir=src_dir, out_checkpoint_file=dst_file, dtype=dtype)
    print(f"Saved kohya checkpoint to '{dst_file}'.")


if __name__ == "__main__":
    main()
================================================
FILE: src/invoke_training/scripts/invoke_generate_images.py
================================================
import argparse
from pathlib import Path
from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum
from invoke_training._shared.tools.generate_images import generate_images
def parse_args():
    """Parse the command-line arguments for generating a dataset of images.

    Returns:
        argparse.Namespace: The parsed arguments. `lora` and `ti` are lists of strings accumulated across every
            occurrence of `-l/--lora` and `--ti` respectively, or None if the flag was never given.
    """
    parser = argparse.ArgumentParser(
        description="Generate a dataset of images from a single prompt. (Typically used to generate prior "
        "preservation/regularization datasets.)"
    )
    parser.add_argument(
        "-o",
        "--out-dir",
        type=str,
        required=True,
        help="Path to the directory where the images will be stored.",
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        required=True,
        help="Name or path of the diffusers model to generate images with. Can be in diffusers format, or a single "
        "stable diffusion checkpoint file. (E.g. 'runwayml/stable-diffusion-v1-5', "
        "'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/realisticVisionV51_v51VAE.safetensors', etc. )",
    )
    parser.add_argument(
        "-v",
        "--variant",
        type=str,
        required=False,
        default=None,
        help="The Hugging Face Hub model variant to use. Only applies if `--model` is a Hugging Face Hub model name.",
    )
    parser.add_argument(
        "-l",
        "--lora",
        type=str,
        nargs="*",
        # "extend" accumulates values across repeated flags. With the default "store" action, each `-l` occurrence
        # replaced the previous one, so the documented `-l a -l b` usage silently dropped all but the last LoRA.
        action="extend",
        help="LoRA models to apply to the base model. The LoRA weight can optionally be provided after a colon "
        "separator. The flag can be repeated. E.g. `-l path/to/lora.bin:0.5 -l path/to/lora_2.safetensors`.",
    )
    parser.add_argument(
        "--ti",
        type=str,
        nargs="*",
        action="extend",
        help="Path(s) to Textual Inversion embeddings to apply to the base model. The flag can be repeated.",
    )
    parser.add_argument(
        "--sd-version",
        type=str,
        required=True,
        help="The Stable Diffusion version. One of: ['SD', 'SDXL'].",
    )
    # One of --prompt or --prompt-file.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("-p", "--prompt", type=str, help="The prompt to use for image generation.")
    group.add_argument("--prompt-file", type=str, help="A file containing prompts. One per line.")
    parser.add_argument(
        "--set-size", type=int, default=1, help="The number of images generated in each 'set' for a given prompt."
    )
    parser.add_argument("--num-sets", type=int, default=1, help="The number of 'sets' to generate for each prompt.")
    parser.add_argument(
        "--height",
        type=int,
        required=True,
        help="The height of the generated images in pixels.",
    )
    parser.add_argument(
        "--width",
        type=int,
        required=True,
        help="The width of the generated images in pixels.",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=0,
        help="Seed for repeatability.",
    )
    parser.add_argument(
        "--enable-cpu-offload",
        default=False,
        action="store_true",
        help="If True, models will be loaded onto the GPU one by one to conserve VRAM.",
    )
    return parser.parse_args()
def parse_lora_args(lora_args: list[str] | None) -> list[tuple[Path, int]]:
loras: list[tuple[Path, int]] = []
lora_args = lora_args or []
for lora in lora_args:
lora_split = lora.split(":")
if len(lora_split) == 1:
# If weight is not specified, assume 1.0.
loras.append((Path(lora_split[0]), 1.0))
elif len(lora_split) == 2:
loras.append((Path(lora_split[0]), float(lora_split[1])))
else:
raise ValueError(f"Invalid lora argument syntax: '{lora}'.")
return loras
def parse_prompt_file(prompt_file: str) -> list[str]:
    """Read prompts from a UTF-8 text file, one prompt per line.

    Leading and trailing whitespace is stripped from each line. Blank lines are skipped, so spacer lines or a trailing
    empty line do not turn into empty-prompt generations.

    Args:
        prompt_file (str): Path to the prompt file.

    Returns:
        list[str]: The non-empty prompts, in file order.
    """
    with open(prompt_file, encoding="utf-8") as f:
        return [stripped for line in f if (stripped := line.strip())]
def main():
    """Command-line entry point: generate image sets for a single prompt or for every prompt in a file."""
    args = parse_args()
    loras = parse_lora_args(args.lora)
    prompts = [args.prompt] if args.prompt else parse_prompt_file(args.prompt_file)

    print(f"Generating {args.num_sets} sets of {args.set_size} images for {len(prompts)} prompts in '{args.out_dir}'.")

    generation_kwargs = dict(
        out_dir=args.out_dir,
        model=args.model,
        hf_variant=args.variant,
        pipeline_version=PipelineVersionEnum(args.sd_version),
        prompts=prompts,
        set_size=args.set_size,
        num_sets=args.num_sets,
        height=args.height,
        width=args.width,
        loras=loras,
        ti_embeddings=args.ti,
        seed=args.seed,
        enable_cpu_offload=args.enable_cpu_offload,
    )
    generate_images(**generation_kwargs)


if __name__ == "__main__":
    main()
================================================
FILE: src/invoke_training/scripts/invoke_train.py
================================================
import argparse
from pathlib import Path
import yaml
from pydantic import TypeAdapter
from invoke_training.config.pipeline_config import PipelineConfig
from invoke_training.pipelines.invoke_train import train
def parse_args():
    """Parse the command-line arguments for running a training pipeline.

    Returns:
        argparse.Namespace: Namespace with a `cfg_file` attribute of type Path.
    """
    cli = argparse.ArgumentParser(description="Run a training pipeline.")
    cli.add_argument("-c", "--cfg-file", type=Path, required=True, help="Path to the YAML training config file.")
    return cli.parse_args()
def main():
    """Load a YAML training config, validate it as a `PipelineConfig`, and run the training pipeline."""
    cfg_path = parse_args().cfg_file

    with open(cfg_path, "r") as cfg_stream:
        raw_cfg = yaml.safe_load(cfg_stream)

    train_config = TypeAdapter(PipelineConfig).validate_python(raw_cfg)
    train(train_config)


if __name__ == "__main__":
    main()
================================================
FILE: src/invoke_training/scripts/invoke_train_ui.py
================================================
import argparse
import uvicorn
from invoke_training.ui.app import build_app
def main():
    """Serve the training UI app with uvicorn on the host/port given on the command line."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--host",
        default="127.0.0.1",
        help="The server host. Set `--host 0.0.0.0` to make the app available on your network.",
    )
    cli.add_argument("--port", default=8000, type=int, help="The server port.")
    opts = cli.parse_args()

    uvicorn.run(build_app(), host=opts.host, port=opts.port)


if __name__ == "__main__":
    main()
================================================
FILE: src/invoke_training/scripts/invoke_visualize_data_loading.py
================================================
import argparse
import os
import time
from pathlib import Path
import numpy as np
import torch
import yaml
from PIL import Image
from pydantic import TypeAdapter
from torch.utils.data import DataLoader
from invoke_training._shared.data.data_loaders.dreambooth_sd_dataloader import build_dreambooth_sd_dataloader
from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
build_image_caption_sd_dataloader,
)
from invoke_training._shared.data.data_loaders.image_pair_preference_sd_dataloader import (
build_image_pair_preference_sd_dataloader,
)
from invoke_training._shared.data.data_loaders.textual_inversion_sd_dataloader import (
build_textual_inversion_sd_dataloader,
)
from invoke_training.config.pipeline_config import PipelineConfig
def save_image(torch_image: torch.Tensor, out_path: Path):
    """Save a normalized torch image to disk.

    Args:
        torch_image (torch.Tensor): Shape=(C, H, W). Pixel values are expected to be normalized in the range
            [-1.0, 1.0]. Values outside of this range are clamped rather than wrapped around.
        out_path (Path): The output path.
    """
    # `.float()` first: NumPy has no bfloat16 dtype, so a bf16 tensor could not be converted directly. The arithmetic
    # below is not in-place, so the caller's tensor is never modified.
    np_image = torch_image.detach().float().cpu().numpy()
    # Map [-1.0, 1.0] -> [0.0, 255.0].
    np_image = (np_image * 0.5 + 0.5) * 255
    # Move channel axis from first dimension to last dimension: (C, H, W) -> (H, W, C).
    np_image = np.moveaxis(np_image, 0, -1)
    # Round and clamp before casting: a bare `astype(np.uint8)` truncates instead of rounding and wraps out-of-range
    # values (e.g. a slightly negative value would become ~255).
    np_image = np.clip(np.rint(np_image), 0, 255).astype(np.uint8)
    Image.fromarray(np_image).save(out_path)
def parse_args():
    """Parse the command-line arguments for the data-loading visualizer.

    Returns:
        argparse.Namespace: Namespace with a `cfg_file` attribute of type Path.
    """
    cli = argparse.ArgumentParser(description="Visualize data loading from a pipeline config.")
    cli.add_argument("-c", "--cfg-file", type=Path, required=True, help="Path to the YAML training config file.")
    return cli.parse_args()
def visualize(data_loader: DataLoader):
    """Step through `data_loader` interactively, printing each batch and saving RGB image batches to disk.

    Output goes to a new `out_<timestamp>/batch_<i>/` directory under the current working directory. Every 4-D tensor
    whose second dimension is 3 is treated as a batch of RGB images, and each image is saved as `<key>_<i>.png`.
    After each batch, the function waits for the user to press Enter.
    """
    run_dir = Path(f"out_{str(time.time()).replace('.', '-')}/")
    os.makedirs(run_dir)

    for batch_idx, batch in enumerate(data_loader):
        print(f"Batch {batch_idx}:")
        batch_dir = run_dir / f"batch_{batch_idx}"
        batch_dir.mkdir()

        written_paths = []
        for key, value in batch.items():
            if not isinstance(value, torch.Tensor):
                print(f"{key}: {value}")
                continue

            print(f"{key}: Tensor.shape={value.shape}")
            looks_like_rgb_batch = len(value.shape) == 4 and value.shape[1] == 3
            if not looks_like_rgb_batch:
                continue
            for image_idx in range(value.shape[0]):
                png_path = batch_dir / f"{key}_{image_idx}.png"
                save_image(value[image_idx, ...], png_path)
                written_paths.append(png_path)

        for png_path in written_paths:
            print(f"Saved image to '{png_path}'.")
        input("\n\nPress Enter to continue to next batch...\n")
def main():
    """Script entry point: load a pipeline config from YAML and visualize the data loader it describes."""
    cli_args = parse_args()

    # Read the raw YAML and validate it as a `PipelineConfig`.
    with open(args_path := cli_args.cfg_file, "r") as cfg_stream:
        raw_cfg = yaml.safe_load(cfg_stream)
    del args_path
    train_config = TypeAdapter(PipelineConfig).validate_python(raw_cfg)

    loader_config = train_config.data_loader
    loader_type = loader_config.type

    # Shuffling is disabled so that batches appear in a deterministic order.
    if loader_type == "IMAGE_CAPTION_SD_DATA_LOADER":
        loader = build_image_caption_sd_dataloader(
            config=loader_config,
            batch_size=train_config.train_batch_size,
            shuffle=False,
        )
    elif loader_type == "TEXTUAL_INVERSION_SD_DATA_LOADER":
        # NOTE(review): presumably the placeholder token is irrelevant for visualization, hence "".
        loader = build_textual_inversion_sd_dataloader(
            config=loader_config,
            placeholder_token="",
            batch_size=train_config.train_batch_size,
            shuffle=False,
        )
    elif loader_type == "DREAMBOOTH_SD_DATA_LOADER":
        loader = build_dreambooth_sd_dataloader(
            config=loader_config,
            batch_size=train_config.train_batch_size,
            shuffle=False,
            sequential_batching=False,
        )
    elif loader_type == "IMAGE_PAIR_PREFERENCE_SD_DATA_LOADER":
        loader = build_image_pair_preference_sd_dataloader(
            config=loader_config,
            batch_size=train_config.train_batch_size,
            shuffle=False,
        )
    else:
        raise ValueError(f"Unexpected data loader type: '{loader_type}'.")

    visualize(loader)
# Run the visualizer only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
================================================
FILE: src/invoke_training/scripts/utils/image_dir_dataset.py
================================================
import os
import typing
import torch
from PIL import Image
class ImageDirDataset(torch.utils.data.Dataset):
    """A simple dataset that loads images from a directory.

    Only regular files located directly inside `dataset_dir` (the scan is not recursive) whose extension
    matches one of `image_extensions` (case-insensitively) are included. Paths are sorted so that indexing
    is deterministic.
    """

    def __init__(
        self,
        dataset_dir: str,
        image_extensions: typing.Optional[list[str]] = None,
    ):
        """
        Args:
            dataset_dir: The directory to scan for images.
            image_extensions: File extensions (including the leading '.') to include. Matching is
                case-insensitive. Defaults to [".png", ".jpg", ".jpeg"].
        """
        super().__init__()
        if image_extensions is None:
            image_extensions = [".png", ".jpg", ".jpeg"]
        # A set gives O(1) membership checks for each directory entry.
        allowed_extensions = {ext.lower() for ext in image_extensions}

        # Determine the list of image paths to include in the dataset.
        self._image_paths: list[str] = []
        for image_file in os.listdir(dataset_dir):
            image_path = os.path.join(dataset_dir, image_file)
            if os.path.isfile(image_path) and os.path.splitext(image_path)[1].lower() in allowed_extensions:
                self._image_paths.append(image_path)
        self._image_paths.sort()

    def _load_image(self, image_path: str) -> "Image.Image":
        # `Image.open` is lazy and holds the file open until the pixel data is read. Using it as a context
        # manager guarantees the handle is released. `convert("RGB")` returns a new, fully loaded image (it
        # drops the alpha channel from RGBA images and repeats channels for greyscale images), so the result
        # stays valid after the file is closed.
        with Image.open(image_path) as image:
            return image.convert("RGB")

    def __len__(self) -> int:
        return len(self._image_paths)

    def __getitem__(self, idx: int):
        """Return a dict with the image's file path (`"image_path"`) and its RGB PIL image (`"image"`)."""
        image_path = self._image_paths[idx]
        image = self._load_image(image_path)
        return {"image_path": image_path, "image": image}
def list_collate_fn(examples):
    """Collate examples by gathering each field into a plain list instead of stacking into a tensor.

    This is the batch format the Moondream model expects.

    Args:
        examples: Sequence of dicts, each with an "image" and an "image_path" entry.

    Returns:
        dict: {"image": [...], "image_path": [...]}, preserving example order.
    """
    return {field: [example[field] for example in examples] for field in ("image", "image_path")}
================================================
FILE: src/invoke_training/ui/__init__.py
================================================
================================================
FILE: src/invoke_training/ui/app.py
================================================
from pathlib import Path
import gradio as gr
from fastapi import FastAPI
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from invoke_training.ui.pages.data_page import DataPage
from invoke_training.ui.pages.training_page import TrainingPage
def build_app():
    """Build the FastAPI application that serves the invoke-training UI.

    Routes:
        - `/`: the `index.html` file located next to this module.
        - `/assets`: static files from the `assets/` directory one level above this module's directory.
        - `/train`: the Gradio training page.
        - `/data`: the Gradio data page.
    """
    ui_dir = Path(__file__).parent
    training_page = TrainingPage()
    data_page = DataPage()

    app = FastAPI()

    @app.get("/")
    async def root():
        return FileResponse(ui_dir / "index.html")

    app.mount("/assets", StaticFiles(directory=ui_dir.parent / "assets"), name="assets")

    # Mount each Gradio page in order; each call gets its own fresh kwargs dict.
    for mount_path, page in (("/train", training_page), ("/data", data_page)):
        app = gr.mount_gradio_app(app, page.app(), mount_path, app_kwargs={"favicon_path": "/assets/favicon.png"})
    return app
================================================
FILE: src/invoke_training/ui/config_groups/__init__.py
================================================
================================================
FILE: src/invoke_training/ui/config_groups/aspect_ratio_bucket_config_group.py
================================================
from typing import Any
import gradio as gr
from invoke_training.config.data.data_loader_config import AspectRatioBucketConfig
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
class AspectRatioBucketConfigGroup(UIConfigElement):
    """UI controls for the optional aspect-ratio bucketing configuration."""

    def __init__(self):
        gr.Markdown(
            "Aspect ratio bucket resolutions are generated as follows:\n"
            "- Iterate over 'first' dimension values from `start_dim` to `end_dim` in steps of size `divisible_by`.\n"
            "- Calculate the 'second' dimension to be as close as possible to the total number of pixels in "
            "`target_resolution`, while still being divisible by `divisible_by`."
        )
        self.enabled = gr.Checkbox(interactive=True, label="Use Aspect Ratio Bucketing")
        # All numeric fields are integer-valued (precision=0).
        self.target_resolution = gr.Number(precision=0, interactive=True, label="target_resolution")
        self.start_dim = gr.Number(precision=0, interactive=True, label="start_dimension")
        self.end_dim = gr.Number(precision=0, interactive=True, label="end_dimension")
        self.divisible_by = gr.Number(precision=0, interactive=True, label="divisible_by")

    def _field_components(self) -> dict[str, gr.components.Component]:
        """Map each `AspectRatioBucketConfig` field name to the UI component that edits it."""
        return {
            "target_resolution": self.target_resolution,
            "start_dim": self.start_dim,
            "end_dim": self.end_dim,
            "divisible_by": self.divisible_by,
        }

    def update_ui_components_with_config_data(
        self, config: AspectRatioBucketConfig | None
    ) -> dict[gr.components.Component, Any]:
        """Populate the UI from `config`; a `None` config unchecks the checkbox and shows placeholder values."""
        is_enabled = config is not None
        if not is_enabled:
            # Placeholder config used only to fill the numeric fields with sensible values.
            config = AspectRatioBucketConfig(target_resolution=512, start_dim=256, end_dim=768, divisible_by=64)

        ui_values = {self.enabled: is_enabled}
        for field_name, component in self._field_components().items():
            ui_values[component] = getattr(config, field_name)
        return ui_values

    def update_config_with_ui_component_data(
        self, orig_config: AspectRatioBucketConfig | None, ui_data: dict[gr.components.Component, Any]
    ) -> AspectRatioBucketConfig | None:
        """Build a config from the UI values, or return `None` when bucketing is disabled.

        Every component owned by this group is popped from `ui_data`, even when disabled, so upstream code
        knows the fields were handled.
        """
        # TODO: Use orig_config?
        is_enabled = ui_data.pop(self.enabled)
        field_values = {name: ui_data.pop(component) for name, component in self._field_components().items()}
        if not is_enabled:
            return None
        return AspectRatioBucketConfig(**field_values)
================================================
FILE: src/invoke_training/ui/config_groups/base_pipeline_config_group.py
================================================
from typing import Any
import gradio as gr
from invoke_training.config.base_pipeline_config import BasePipelineConfig
from invoke_training.config.pipeline_config import PipelineConfig
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
class BasePipelineConfigGroup(UIConfigElement):
    """UI controls shared by all pipelines: output dir, training length, save/validate frequency and seed.

    Each "steps or epochs" setting is edited as a dropdown (whose value is the config field name) plus a
    number box holding that field's value.
    """

    def __init__(self):
        self.base_output_dir = gr.Textbox(
            label="Base Output Directory",
            info="The base output directory where the training outputs (model checkpoints, logs,"
            " intermediate predictions) will be written.",
            interactive=True,
        )
        with gr.Row():
            with gr.Column():
                self.max_train_steps_or_epochs_dropdown = gr.Dropdown(
                    label="Training Length",
                    info="Train for a fixed number of gradient update steps or epochs.",
                    choices=["max_train_steps", "max_train_epochs"],
                    interactive=True,
                )
                self.max_train_steps_or_epochs = gr.Number(label="Steps or Epochs", precision=0, interactive=True)
            with gr.Column():
                self.save_every_n_steps_or_epochs_dropdown = gr.Dropdown(
                    label="Checkpoint Save Frequency",
                    info="Save a checkpoint every N gradient update steps or epochs.",
                    choices=["save_every_n_steps", "save_every_n_epochs"],
                    interactive=True,
                )
                self.save_every_n_steps_or_epochs = gr.Number(
                    label="Steps or Epochs", precision=0, interactive=True
                )
            with gr.Column():
                self.validate_every_n_steps_or_epochs_dropdown = gr.Dropdown(
                    label="Validation Frequency",
                    info="Save validation images every N gradient update steps or epochs.",
                    choices=["validate_every_n_steps", "validate_every_n_epochs"],
                    interactive=True,
                )
                self.validate_every_n_steps_or_epochs = gr.Number(
                    label="Steps or Epochs", precision=0, interactive=True
                )
        self.seed = gr.Number(
            label="Seed",
            info="Set to any constant integer for consistent training results. If set to null, training"
            " will be non-deterministic.",
            precision=0,
            interactive=True,
        )

    def _steps_or_epochs_fields(self):
        """Return (dropdown, number, epochs_field, steps_field) for each steps-or-epochs setting, in UI order."""
        return (
            (
                self.max_train_steps_or_epochs_dropdown,
                self.max_train_steps_or_epochs,
                "max_train_epochs",
                "max_train_steps",
            ),
            (
                self.save_every_n_steps_or_epochs_dropdown,
                self.save_every_n_steps_or_epochs,
                "save_every_n_epochs",
                "save_every_n_steps",
            ),
            (
                self.validate_every_n_steps_or_epochs_dropdown,
                self.validate_every_n_steps_or_epochs,
                "validate_every_n_epochs",
                "validate_every_n_steps",
            ),
        )

    def update_ui_components_with_config_data(self, config: BasePipelineConfig) -> dict[gr.components.Component, Any]:
        """Populate the UI from `config`.

        For each steps-or-epochs pair, the epoch-based field takes precedence when both are set.

        Raises:
            ValueError: If neither field of a steps-or-epochs pair is set.
        """
        selections = []
        for dropdown, number, epochs_field, steps_field in self._steps_or_epochs_fields():
            epochs_value = getattr(config, epochs_field)
            steps_value = getattr(config, steps_field)
            if epochs_value is not None:
                selections.append((dropdown, epochs_field, number, epochs_value))
            elif steps_value is not None:
                selections.append((dropdown, steps_field, number, steps_value))
            else:
                raise ValueError(f"One of {epochs_field} or {steps_field} must be set.")

        ui_values = {self.seed: config.seed, self.base_output_dir: config.base_output_dir}
        for dropdown, chosen_field, number, chosen_value in selections:
            ui_values[dropdown] = chosen_field
            ui_values[number] = chosen_value
        return ui_values

    def update_config_with_ui_component_data(
        self, orig_config: PipelineConfig, ui_data: dict[gr.components.Component, Any]
    ) -> PipelineConfig:
        """Return a deep copy of `orig_config` updated with (and consuming) this group's values in `ui_data`.

        For each steps-or-epochs pair, the field chosen in the dropdown gets the number box's value and the
        other field is reset to `None`.
        """
        updated = orig_config.model_copy(deep=True)
        updated.seed = ui_data.pop(self.seed)
        updated.base_output_dir = ui_data.pop(self.base_output_dir)

        for dropdown, number, epochs_field, steps_field in self._steps_or_epochs_fields():
            if ui_data.pop(dropdown) == epochs_field:
                chosen_field, cleared_field = epochs_field, steps_field
            else:
                chosen_field, cleared_field = steps_field, epochs_field
            # Assign the chosen field before clearing the other one.
            setattr(updated, chosen_field, ui_data.pop(number))
            setattr(updated, cleared_field, None)
        return updated
================================================
FILE: src/invoke_training/ui/config_groups/dataset_config_group.py
================================================
from typing import Any
import gradio as gr
from invoke_training.config.data.dataset_config import (
HFHubImageCaptionDatasetConfig,
ImageCaptionDatasetConfig,
ImageCaptionDirDatasetConfig,
ImageCaptionJsonlDatasetConfig,
ImageDirDatasetConfig,
)
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
# Every dataset type the UI knows how to configure. `DatasetConfigGroup` offers the subset allowed by a given
# pipeline in the "Dataset Type" dropdown, keeping the order listed here.
ALL_DATASET_TYPES = [
    "HF_HUB_IMAGE_CAPTION_DATASET",
    "IMAGE_CAPTION_JSONL_DATASET",
    "IMAGE_CAPTION_DIR_DATASET",
    "IMAGE_DIR_DATASET",
]
class HFHubImageCaptionDatasetConfigGroup(UIConfigElement):
    """UI controls for an image-caption dataset hosted on the Hugging Face Hub."""

    def __init__(self):
        self.dataset_name = gr.Textbox(
            interactive=True,
            label="Dataset Name",
            info="Hugging Face Dataset Name (e.g., owner/RepoID).",
        )
        with gr.Row():
            self.dataset_config_name = gr.Textbox(
                interactive=True,
                label="Dataset Config Name (Optional)",
                info="The Hugging Face dataset config name. Leave as None if there's only one config.",
            )
        with gr.Row():
            self.hf_cache_dir = gr.Textbox(
                interactive=True,
                label="Cache Directory",
                info="The Hugging Face cache directory to use for dataset downloads. If None, the default value"
                " will be used (usually '~/.cache/huggingface/datasets').",
            )
        # NOTE: image_column / caption_column are not exposed in the UI; the config's own defaults are used.

    def update_ui_components_with_config_data(
        self, config: HFHubImageCaptionDatasetConfig | None
    ) -> dict[gr.components.Component, Any]:
        """Populate the UI from `config`, or with empty values when no config is given."""
        if not config:
            return {
                self.dataset_name: "",
                self.dataset_config_name: None,
                self.hf_cache_dir: None,
            }
        return {
            self.dataset_name: config.dataset_name,
            self.dataset_config_name: config.dataset_config_name,
            self.hf_cache_dir: config.hf_cache_dir,
        }

    def update_config_with_ui_component_data(
        self, orig_config: HFHubImageCaptionDatasetConfig | None, ui_data: dict[gr.components.Component, Any]
    ) -> HFHubImageCaptionDatasetConfig:
        """Build a fresh config from (and consume) this group's values in `ui_data`."""
        assert orig_config is None
        name = ui_data.pop(self.dataset_name)
        # Empty textboxes are mapped to None for the optional fields.
        config_name = ui_data.pop(self.dataset_config_name) or None
        cache_dir = ui_data.pop(self.hf_cache_dir) or None
        return HFHubImageCaptionDatasetConfig(
            dataset_name=name,
            dataset_config_name=config_name,
            hf_cache_dir=cache_dir,
        )
class ImageCaptionJsonlDatasetConfigGroup(UIConfigElement):
    """UI controls for an image-caption dataset described by a `.jsonl` file."""

    def __init__(self):
        self.jsonl_path = gr.Textbox(
            interactive=True,
            label="jsonl_path",
            info="Path to the dataset `.jsonl` file.",
        )
        self.image_column = gr.Textbox(
            interactive=True,
            label="image_column",
            info="The name of the field in the `.jsonl` containing image file paths.",
        )
        self.caption_column = gr.Textbox(
            interactive=True,
            label="caption_column",
            info="The name of the field in the `.jsonl` containing image captions.",
        )
        self.keep_in_memory = gr.Checkbox(
            interactive=True,
            label="keep_in_memory",
            info="If True, the entire dataset will be kept in RAM. This increases speed for small datasets at the "
            "cost of higher RAM usage.",
        )

    def update_ui_components_with_config_data(
        self, config: ImageCaptionJsonlDatasetConfig | None
    ) -> dict[gr.components.Component, Any]:
        """Populate the UI from `config`, falling back to the config class's defaults when it is None."""
        # A placeholder config (empty path) exposes the class's default column names / keep_in_memory.
        shown = config if config is not None else ImageCaptionJsonlDatasetConfig(jsonl_path="")
        return {
            self.jsonl_path: shown.jsonl_path,
            self.image_column: shown.image_column,
            self.caption_column: shown.caption_column,
            self.keep_in_memory: shown.keep_in_memory,
        }

    def update_config_with_ui_component_data(
        self, orig_config: ImageCaptionJsonlDatasetConfig | None, ui_data: dict[gr.components.Component, Any]
    ) -> ImageCaptionJsonlDatasetConfig:
        """Build a fresh config from (and consume) this group's values in `ui_data`."""
        assert orig_config is None
        field_values = {
            "jsonl_path": ui_data.pop(self.jsonl_path),
            "image_column": ui_data.pop(self.image_column),
            "caption_column": ui_data.pop(self.caption_column),
            "keep_in_memory": ui_data.pop(self.keep_in_memory),
        }
        return ImageCaptionJsonlDatasetConfig(**field_values)
class ImageCaptionDirDatasetConfigGroup(UIConfigElement):
    """UI controls for an image-caption dataset stored as a directory."""

    def __init__(self):
        with gr.Row():
            self.dataset_dir = gr.Textbox(
                interactive=True,
                label="dataset_dir",
                info="The path to the dataset directory.",
            )
        with gr.Row():
            self.keep_in_memory = gr.Checkbox(
                interactive=True,
                label="keep_in_memory",
                info="If True, the entire dataset will be kept in RAM. This increases speed for small datasets at the "
                "cost of higher RAM usage.",
            )

    def update_ui_components_with_config_data(
        self, config: ImageCaptionDirDatasetConfig | None
    ) -> dict[gr.components.Component, Any]:
        """Populate the UI from `config`, or with an empty path / unchecked box when no config is given."""
        if not config:
            return {self.dataset_dir: "", self.keep_in_memory: False}
        return {self.dataset_dir: config.dataset_dir, self.keep_in_memory: config.keep_in_memory}

    def update_config_with_ui_component_data(
        self, orig_config: ImageCaptionDirDatasetConfig | None, ui_data: dict[gr.components.Component, Any]
    ) -> ImageCaptionDirDatasetConfig:
        """Build a fresh config from (and consume) this group's values in `ui_data`."""
        assert orig_config is None
        directory = ui_data.pop(self.dataset_dir)
        in_memory = ui_data.pop(self.keep_in_memory)
        return ImageCaptionDirDatasetConfig(dataset_dir=directory, keep_in_memory=in_memory)
class ImageDirDatasetConfigGroup(UIConfigElement):
    """UI controls for an image-only dataset stored as a directory."""

    def __init__(self):
        with gr.Row():
            self.dataset_dir = gr.Textbox(
                interactive=True,
                label="dataset_dir",
                info="The path to the dataset directory.",
            )
        with gr.Row():
            self.keep_in_memory = gr.Checkbox(
                interactive=True,
                label="keep_in_memory",
                info="If True, the entire dataset will be kept in RAM. This increases speed for small datasets at the "
                "cost of higher RAM usage.",
            )

    def update_ui_components_with_config_data(
        self, config: ImageDirDatasetConfig | None
    ) -> dict[gr.components.Component, Any]:
        """Populate the UI from `config`, or with an empty path / unchecked box when no config is given."""
        if not config:
            return {self.dataset_dir: "", self.keep_in_memory: False}
        return {self.dataset_dir: config.dataset_dir, self.keep_in_memory: config.keep_in_memory}

    def update_config_with_ui_component_data(
        self, orig_config: ImageDirDatasetConfig | None, ui_data: dict[gr.components.Component, Any]
    ) -> ImageDirDatasetConfig:
        """Build a fresh config from (and consume) this group's values in `ui_data`."""
        assert orig_config is None
        directory = ui_data.pop(self.dataset_dir)
        in_memory = ui_data.pop(self.keep_in_memory)
        return ImageDirDatasetConfig(dataset_dir=directory, keep_in_memory=in_memory)
class DatasetConfigGroup(UIConfigElement):
    """Dataset type selector plus one sub-form per supported dataset type.

    Only the sub-form matching the selected type is visible. When building a config, every sub-form's values
    are read back (and consumed from `ui_data`), and the one matching the selected type is returned.
    """

    def __init__(self, allowed_types: list[str]):
        self.type = gr.Dropdown(
            choices=[dataset_type for dataset_type in ALL_DATASET_TYPES if dataset_type in allowed_types],
            label="Dataset Type",
            info="The type of dataset to use for training. See "
            "https://invoke-ai.github.io/invoke-training/concepts/dataset_formats/ for a description of each format.",
            interactive=True,
        )

        with gr.Group() as self.hf_hub_image_caption_dataset_config_group:
            self.hf_hub_image_caption_dataset_config = HFHubImageCaptionDatasetConfigGroup()
        with gr.Group() as self.image_caption_jsonl_dataset_config_group:
            self.image_caption_jsonl_dataset_config = ImageCaptionJsonlDatasetConfigGroup()
        with gr.Group() as self.image_caption_dir_dataset_config_group:
            self.image_caption_dir_dataset_config = ImageCaptionDirDatasetConfigGroup()
        with gr.Group() as self.image_dir_dataset_config_group:
            self.image_dir_dataset_config = ImageDirDatasetConfigGroup()

        # Toggle sub-form visibility whenever the selected dataset type changes.
        self.type.change(
            self._on_type_change,
            inputs=[self.type],
            outputs=[group for _, group, _ in self._sections()],
        )

    def _sections(self) -> list[tuple[str, gr.Group, UIConfigElement]]:
        """Return (dataset type, visibility group, sub-config element) for each dataset type, in UI order."""
        return [
            (
                "HF_HUB_IMAGE_CAPTION_DATASET",
                self.hf_hub_image_caption_dataset_config_group,
                self.hf_hub_image_caption_dataset_config,
            ),
            (
                "IMAGE_CAPTION_JSONL_DATASET",
                self.image_caption_jsonl_dataset_config_group,
                self.image_caption_jsonl_dataset_config,
            ),
            (
                "IMAGE_CAPTION_DIR_DATASET",
                self.image_caption_dir_dataset_config_group,
                self.image_caption_dir_dataset_config,
            ),
            (
                "IMAGE_DIR_DATASET",
                self.image_dir_dataset_config_group,
                self.image_dir_dataset_config,
            ),
        ]

    def _on_type_change(self, type: str):
        """Show only the sub-form whose dataset type matches the newly selected `type`."""
        return {group: gr.Group(visible=type == section_type) for section_type, group, _ in self._sections()}

    def update_ui_components_with_config_data(
        self, config: ImageCaptionDatasetConfig
    ) -> dict[gr.components.Component, Any]:
        """Select `config.type`, show its sub-form, and populate every sub-form.

        The sub-form matching `config.type` receives `config`; the others receive `None`.
        """
        sections = self._sections()
        ui_values = {self.type: config.type}
        for section_type, group, _ in sections:
            ui_values[group] = gr.Group(visible=config.type == section_type)
        for section_type, _, element in sections:
            matching_config = config if config.type == section_type else None
            ui_values.update(element.update_ui_components_with_config_data(matching_config))
        return ui_values

    def update_config_with_ui_component_data(
        self, orig_config: ImageCaptionDatasetConfig, ui_data: dict[gr.components.Component, Any]
    ) -> ImageCaptionDatasetConfig:
        """Return the dataset config for the selected type, built from (and consuming) the values in `ui_data`.

        Raises:
            ValueError: If the selected dataset type is not recognized.
        """
        # TODO: Use orig_config.
        # All sub-forms are built so that each one consumes its own entries from `ui_data`.
        candidates = [
            (section_type, element.update_config_with_ui_component_data(None, ui_data))
            for section_type, _, element in self._sections()
        ]
        dataset_type = ui_data.pop(self.type)
        for section_type, candidate in candidates:
            if dataset_type == section_type:
                return candidate
        raise ValueError(f"Unknown dataset type: {dataset_type}")
================================================
FILE: src/invoke_training/ui/config_groups/flux_lora_config_group.py
================================================
import typing
import gradio as gr
from invoke_training.pipelines.flux.lora.config import FluxLoraConfig
from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup
from invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import (
ImageCaptionSDDataLoaderConfigGroup,
)
from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
from invoke_training.ui.utils.utils import get_typing_literal_options
class FluxLoraConfigGroup(UIConfigElement):
def __init__(self):
    """Build the Flux LoRA config UI.

    Creates, in layout order, the components for the base model, training outputs, data loader, optimizer,
    LR scheduler, core/advanced training options, and validation settings. The nested
    `BasePipelineConfigGroup`, `ImageCaptionSDDataLoaderConfigGroup` and `OptimizerConfigGroup` build their
    own components.
    """
    gr.Markdown("## Basic Configs")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Tab("Base Model"):
                self.model = gr.Textbox(
                    label="Model",
                    info="The base model. Can be a Hugging Face Hub model name, or a path to a local model (in "
                    "diffusers or checkpoint format).",
                    type="text",
                    interactive=True,
                )
                # Flux model doesn't use hf_variant
        with gr.Column(scale=3):
            with gr.Tab("Training Outputs"):
                self.base_pipeline_config_group = BasePipelineConfigGroup()
                self.max_checkpoints = gr.Number(
                    label="Maximum Number of Checkpoints",
                    info="The maximum number of checkpoints to keep on disk from this training run. Earlier "
                    "checkpoints will be deleted to respect this limit.",
                    interactive=True,
                    precision=0,
                )

    gr.Markdown("## Data Configs")
    self.image_caption_sd_data_loader_config_group = ImageCaptionSDDataLoaderConfigGroup()

    gr.Markdown("## Optimizer Configs")
    self.optimizer_config_group = OptimizerConfigGroup()

    gr.Markdown("## Scheduler Configs")
    with gr.Row():
        with gr.Column():
            self.lr_scheduler = gr.Dropdown(
                label="Learning Rate Scheduler",
                choices=get_typing_literal_options(FluxLoraConfig, "lr_scheduler"),
                interactive=True,
            )
            self.lr_warmup_steps = gr.Number(
                label="Learning Rate Warmup Steps",
                info="Number of steps for the warmup in the lr scheduler.",
                interactive=True,
                precision=0,
            )

    gr.Markdown("## General Training Configs")
    with gr.Tab("Core"):
        with gr.Row():
            self.train_transformer = gr.Checkbox(label="Train Transformer", interactive=True)
        with gr.Row():
            self.transformer_learning_rate = gr.Number(
                label="Transformer Learning Rate",
                info="The transformer learning rate. Set to 0 or leave empty to inherit from the base optimizer "
                "learning rate.",
                interactive=True,
            )
        with gr.Row():
            self.gradient_accumulation_steps = gr.Number(
                label="Gradient Accumulation Steps",
                info="Number of update steps to accumulate before performing a backward/update pass.",
                interactive=True,
                precision=0,
            )
            self.gradient_checkpointing = gr.Checkbox(
                label="Gradient Checkpointing",
                info="Whether to use gradient checkpointing to save memory at the expense of slower backward pass.",
                interactive=True,
            )
        # Training/saving/validating steps/epochs are handled by BasePipelineConfigGroup
    with gr.Tab("Advanced"):
        with gr.Column():
            self.lora_rank_dim = gr.Number(
                label="LoRA Rank Dim",
                info="The rank dimension to use for the LoRA layers. Increasing the rank dimension increases"
                " the model's expressivity, but also increases the size of the generated LoRA model.",
                interactive=True,
                precision=0,
            )
            self.min_snr_gamma = gr.Number(
                label="Minimum SNR Gamma",
                info="min_snr_gamma acts like an upper bound on the weight of samples with low noise "
                "levels. If None, then Min-SNR weighting will not be applied. If enabled, the recommended "
                "value is min_snr_gamma = 5.0.",
                interactive=True,
            )
            self.max_grad_norm = gr.Number(
                label="Max Gradient Norm",
                info="Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).",
                interactive=True,
            )
            self.train_batch_size = gr.Number(
                label="Batch Size",
                info="The Training Batch Size - Higher values require increasing amounts of VRAM.",
                precision=0,
                interactive=True,
            )
            self.weight_dtype = gr.Dropdown(
                label="Weight Data Type",
                choices=get_typing_literal_options(FluxLoraConfig, "weight_dtype"),
                info="The data type to use for model weights during training.",
                interactive=True,
            )
            self.mixed_precision = gr.Dropdown(
                label="Mixed Precision",
                choices=get_typing_literal_options(FluxLoraConfig, "mixed_precision"),
                info="The mixed precision mode to use.",
                interactive=True,
            )
            self.lora_checkpoint_format = gr.Dropdown(
                label="LoRA Checkpoint Format",
                choices=get_typing_literal_options(FluxLoraConfig, "lora_checkpoint_format"),
                info="The format of the LoRA checkpoint to save.",
                interactive=True,
            )
            self.timestep_sampler = gr.Dropdown(
                label="Timestep Sampler",
                choices=get_typing_literal_options(FluxLoraConfig, "timestep_sampler"),
                info="The timestep sampler to use.",
                interactive=True,
            )
            self.discrete_flow_shift = gr.Number(
                label="Discrete Flow Shift",
                info="The shift parameter for the discrete flow. Only used if timestep_sampler is 'shift'.",
                interactive=True,
            )
            self.sigmoid_scale = gr.Number(
                label="Sigmoid Scale",
                info="The scale parameter for the sigmoid function. Only used if timestep_sampler is 'shift'.",
                interactive=True,
            )
            self.lora_scale = gr.Number(
                label="LoRA Scale",
                info="The scale parameter for the LoRA layers.",
                interactive=True,
            )
            self.guidance_scale = gr.Number(
                label="Guidance Scale",
                info="The guidance scale for the Flux model.",
                interactive=True,
            )
            self.use_masks = gr.Checkbox(
                label="Use Masks",
                info="If True, image masks will be applied to weight the loss during training. The dataset must "
                "contain masks for this feature to be used.",
                interactive=True,
            )
            self.prediction_type = gr.Dropdown(
                label="Prediction Type",
                choices=["epsilon", "v_prediction", None],
                info="The prediction type that will be used for training.",
                interactive=True,
            )

    gr.Markdown("## Validation")
    with gr.Group():
        self.validation_prompts = gr.Textbox(
            label="Validation Prompts",
            info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' "
            "delimiter. For example: `positive prompt[NEG]negative prompt`. ",
            lines=5,
            interactive=True,
        )
        self.num_validation_images_per_prompt = gr.Number(
            label="# of Validation Images to Generate per Prompt", precision=0, interactive=True
        )
def get_ui_output_components(self) -> list[gr.components.Component]:
    """Return every UI component owned by this group, including those of its nested config groups.

    Order: this group's own components, then the base pipeline, data loader and optimizer groups' components.
    """
    # The LoRA target-module settings have no UI component, so they are not listed here.
    own_components = [
        self.model,
        self.train_transformer,
        self.transformer_learning_rate,
        self.gradient_accumulation_steps,
        self.gradient_checkpointing,
        self.lr_scheduler,
        self.lr_warmup_steps,
        self.lora_rank_dim,
        self.min_snr_gamma,
        self.max_grad_norm,
        self.train_batch_size,
        self.weight_dtype,
        self.mixed_precision,
        self.lora_checkpoint_format,
        self.timestep_sampler,
        self.discrete_flow_shift,
        self.sigmoid_scale,
        self.lora_scale,
        self.guidance_scale,
        self.use_masks,
        self.prediction_type,
        self.validation_prompts,
        self.num_validation_images_per_prompt,
        self.max_checkpoints,
    ]
    return [
        *own_components,
        *self.base_pipeline_config_group.get_ui_output_components(),
        *self.image_caption_sd_data_loader_config_group.get_ui_output_components(),
        *self.optimizer_config_group.get_ui_output_components(),
    ]
def update_ui_components_with_config_data(
    self, config: FluxLoraConfig
) -> dict[gr.components.Component, typing.Any]:
    """Populate this group's UI components (and those of the nested config groups) from `config`.

    This is best-effort:
    - A failure in a nested group is printed, and the values gathered so far are still returned.
    - A failure building this group's own values is printed, and only the model field is updated.
    """
    try:
        ui_values = {
            self.model: config.model,
            self.train_transformer: config.train_transformer,
            self.transformer_learning_rate: config.transformer_learning_rate,
            self.gradient_accumulation_steps: config.gradient_accumulation_steps,
            self.gradient_checkpointing: config.gradient_checkpointing,
            self.lr_scheduler: config.lr_scheduler,
            self.lr_warmup_steps: config.lr_warmup_steps,
            self.lora_rank_dim: config.lora_rank_dim,
            self.min_snr_gamma: config.min_snr_gamma,
            self.max_grad_norm: config.max_grad_norm,
            self.train_batch_size: config.train_batch_size,
            self.weight_dtype: config.weight_dtype,
            self.mixed_precision: config.mixed_precision,
            self.lora_checkpoint_format: config.lora_checkpoint_format,
            self.timestep_sampler: config.timestep_sampler,
            self.discrete_flow_shift: config.discrete_flow_shift,
            self.sigmoid_scale: config.sigmoid_scale,
            self.lora_scale: config.lora_scale,
            self.guidance_scale: config.guidance_scale,
            self.use_masks: config.use_masks,
            self.prediction_type: config.prediction_type,
            self.validation_prompts: config.validation_prompts,
            self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,
            self.max_checkpoints: config.max_checkpoints,
        }

        # Lambdas defer attribute access (e.g. `config.optimizer`) so that it happens inside each try block.
        nested_updates = (
            ("base pipeline", lambda: self.base_pipeline_config_group.update_ui_components_with_config_data(config)),
            (
                "optimizer",
                lambda: self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer),
            ),
            (
                "data loader",
                lambda: self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(
                    config.data_loader
                ),
            ),
        )
        for label, compute_update in nested_updates:
            try:
                ui_values.update(compute_update())
            except Exception as e:
                print(f"Error updating {label} config: {e}")

        # The equality check against get_ui_output_components() is intentionally skipped: it currently fails
        # because of the nested config groups.
        return ui_values
    except Exception as e:
        print(f"Error in update_ui_components_with_config_data: {e}")
        # Return a minimal update dict to avoid UI errors.
        return {self.model: config.model}
def update_config_with_ui_component_data(  # noqa: C901
    self, orig_config: FluxLoraConfig, ui_data: dict[gr.components.Component, typing.Any]
) -> FluxLoraConfig:
    """Return a new FluxLoraConfig with the current UI values applied on top of `orig_config`.

    The update is best-effort: each nested config section is updated inside its own try/except and failures are
    reported with print() instead of being raised. If an error escapes those sections, `orig_config` (or the
    default config built below when it was None) is returned.

    Args:
        orig_config: The config to start from. If None, a default config is created with
            model="black-forest-labs/FLUX.1-dev" and AdamOptimizerConfig().
        ui_data: Mapping from UI component to its current value. The caller's dict is not modified; a shallow
            copy of it is consumed instead.

    Returns:
        The updated config.
    """
    try:
        # Handle the case where orig_config might be None.
        if orig_config is None:
            from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig
            from invoke_training.pipelines.flux.lora.config import FluxLoraConfig

            orig_config = FluxLoraConfig(
                model="black-forest-labs/FLUX.1-dev",
                optimizer=AdamOptimizerConfig(),
            )

        # The deep copy already carries over every field that has no UI component (e.g. the LoRA target module
        # lists) as independent copies, so those fields need no explicit re-assignment.
        new_config = orig_config.model_copy(deep=True)
        # Pop from a copy so that the caller's dict is left intact.
        ui_data_copy = ui_data.copy()

        def safe_pop(component, default=None):
            """Pop `component` from the working copy; on failure, report it and return `default`."""
            try:
                return ui_data_copy.pop(component)
            except (KeyError, TypeError) as e:
                print(f"Error popping {component}: {e}")
                return default

        new_config.model = safe_pop(self.model, new_config.model)
        new_config.train_transformer = safe_pop(self.train_transformer, new_config.train_transformer)
        # Note: train_text_encoder and text_encoder_learning_rate are not supported for Flux LoRA.
        # A learning rate of 0 is stored as None.
        transformer_lr_value = safe_pop(self.transformer_learning_rate, new_config.transformer_learning_rate)
        new_config.transformer_learning_rate = None if transformer_lr_value == 0 else transformer_lr_value
        new_config.gradient_accumulation_steps = safe_pop(
            self.gradient_accumulation_steps, new_config.gradient_accumulation_steps
        )
        new_config.gradient_checkpointing = safe_pop(self.gradient_checkpointing, new_config.gradient_checkpointing)
        # Training/saving/validating steps/epochs are handled by BasePipelineConfigGroup.
        new_config.lr_scheduler = safe_pop(self.lr_scheduler, new_config.lr_scheduler)
        new_config.lr_warmup_steps = safe_pop(self.lr_warmup_steps, new_config.lr_warmup_steps)
        new_config.lora_rank_dim = safe_pop(self.lora_rank_dim, new_config.lora_rank_dim)
        new_config.min_snr_gamma = safe_pop(self.min_snr_gamma, new_config.min_snr_gamma)
        # A max gradient norm of 0 is stored as None.
        max_grad_norm_value = safe_pop(self.max_grad_norm, new_config.max_grad_norm)
        new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value
        new_config.train_batch_size = safe_pop(self.train_batch_size, new_config.train_batch_size)
        new_config.weight_dtype = safe_pop(self.weight_dtype, new_config.weight_dtype)
        new_config.mixed_precision = safe_pop(self.mixed_precision, new_config.mixed_precision)
        new_config.lora_checkpoint_format = safe_pop(self.lora_checkpoint_format, new_config.lora_checkpoint_format)
        new_config.timestep_sampler = safe_pop(self.timestep_sampler, new_config.timestep_sampler)
        new_config.discrete_flow_shift = safe_pop(self.discrete_flow_shift, new_config.discrete_flow_shift)
        new_config.sigmoid_scale = safe_pop(self.sigmoid_scale, new_config.sigmoid_scale)
        new_config.lora_scale = safe_pop(self.lora_scale, new_config.lora_scale)
        new_config.guidance_scale = safe_pop(self.guidance_scale, new_config.guidance_scale)
        new_config.use_masks = safe_pop(self.use_masks, new_config.use_masks)
        new_config.prediction_type = safe_pop(self.prediction_type, new_config.prediction_type)
        new_config.max_checkpoints = safe_pop(self.max_checkpoints, new_config.max_checkpoints)

        # Validation prompts: a textbox yields a single string, while the config field presumably holds a list of
        # prompts (as in the other pipeline configs), so split the text into one prompt per non-empty line.
        # A missing value leaves the existing prompts untouched.
        # NOTE(review): update_ui_components_with_config_data passes the prompt list straight to the textbox;
        # joining it with "\n" there would make the round trip lossless - TODO confirm.
        try:
            validation_prompts_value = safe_pop(self.validation_prompts, None)
            if isinstance(validation_prompts_value, str):
                new_config.validation_prompts = [
                    line.strip() for line in validation_prompts_value.splitlines() if line.strip()
                ]
            elif validation_prompts_value is not None:
                new_config.validation_prompts = list(validation_prompts_value)
        except Exception as e:
            print(f"Error processing validation prompts: {e}")
        new_config.num_validation_images_per_prompt = safe_pop(
            self.num_validation_images_per_prompt, new_config.num_validation_images_per_prompt
        )

        # Update nested configs. Each section is best-effort.
        try:
            data_loader_config_group = self.image_caption_sd_data_loader_config_group
            # The data loader group handles a None data_loader itself.
            new_config.data_loader = data_loader_config_group.update_config_with_ui_component_data(
                new_config.data_loader, ui_data_copy
            )
        except Exception as e:
            print(f"Error updating data loader config: {e}")
        try:
            base_pipeline_group = self.base_pipeline_config_group
            new_config = base_pipeline_group.update_config_with_ui_component_data(new_config, ui_data_copy)
        except Exception as e:
            print(f"Error updating base pipeline config: {e}")
        try:
            # Handle the case where optimizer might be None.
            if new_config.optimizer is None:
                from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig

                new_config.optimizer = AdamOptimizerConfig()
            new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(
                new_config.optimizer, ui_data_copy
            )
        except Exception as e:
            print(f"Error updating optimizer config: {e}")

        # Lenient sanity check: report (rather than assert on) UI values that were not transferred.
        if len(ui_data_copy) > 0:
            print(f"Warning: {len(ui_data_copy)} UI components were not transferred to the config")
        return new_config
    except Exception as e:
        print(f"Error in update_config_with_ui_component_data: {e}")
        # Return the original config to avoid errors.
        return orig_config
================================================
FILE: src/invoke_training/ui/config_groups/image_caption_sd_data_loader_config_group.py
================================================
from typing import Any
import gradio as gr
from invoke_training.config.data.data_loader_config import ImageCaptionSDDataLoaderConfig
from invoke_training.ui.config_groups.aspect_ratio_bucket_config_group import AspectRatioBucketConfigGroup
from invoke_training.ui.config_groups.dataset_config_group import DatasetConfigGroup
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
class ImageCaptionSDDataLoaderConfigGroup(UIConfigElement):
    """UI controls for an ImageCaptionSDDataLoaderConfig.

    Covers the dataset source, the data loading options (resolution, workers, crop/flip, caption prefix) and the
    aspect ratio bucketing sub-group.
    """

    def __init__(self):
        with gr.Tab("Data Source Configs"):
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Group():
                        self.dataset = DatasetConfigGroup(
                            allowed_types=[
                                "HF_HUB_IMAGE_CAPTION_DATASET",
                                "IMAGE_CAPTION_JSONL_DATASET",
                                "IMAGE_CAPTION_DIR_DATASET",
                            ]
                        )
                with gr.Column(scale=3):
                    with gr.Tab("Data Loading Configs"):
                        with gr.Group():
                            with gr.Row():
                                self.resolution = gr.Number(
                                    label="Resolution",
                                    info="The resolution for input images. All of the images in the dataset will be"
                                    " resized to this resolution unless the aspect_ratio_buckets config is set.",
                                    precision=0,
                                    interactive=True,
                                )
                                self.dataloader_num_workers = gr.Number(
                                    label="Dataloading Workers",
                                    info="Number of subprocesses to use for data loading. 0 means that the data will"
                                    " be loaded in the main process.",
                                    precision=0,
                                    interactive=True,
                                )
                            with gr.Row():
                                self.center_crop = gr.Checkbox(
                                    label="Center Crop",
                                    info="If set, input images will be center-cropped to the target resolution."
                                    " Otherwise, input images will be randomly cropped to the target resolution.",
                                    interactive=True,
                                )
                                self.random_flip = gr.Checkbox(
                                    label="Random Flip",
                                    info="If set, random flip augmentations will be applied to input images.",
                                    interactive=True,
                                )
                            self.caption_prefix = gr.Textbox(
                                label="Caption Prefix",
                                info="A prefix that will be prepended to all captions."
                                " If None, no prefix will be added.",
                                interactive=True,
                            )
                    with gr.Tab("Aspect Ratio Bucketing Configs"):
                        self.aspect_ratio_bucket_config_group = AspectRatioBucketConfigGroup()

    def update_ui_components_with_config_data(
        self, config: ImageCaptionSDDataLoaderConfig
    ) -> dict[gr.components.Component, Any]:
        """Return a mapping from each UI component (including nested groups) to its value in `config`."""
        update_dict = {
            self.resolution: config.resolution,
            self.center_crop: config.center_crop,
            self.random_flip: config.random_flip,
            self.caption_prefix: config.caption_prefix,
            self.dataloader_num_workers: config.dataloader_num_workers,
        }
        update_dict.update(self.dataset.update_ui_components_with_config_data(config.dataset))
        update_dict.update(
            self.aspect_ratio_bucket_config_group.update_ui_components_with_config_data(config.aspect_ratio_buckets)
        )
        return update_dict

    def update_config_with_ui_component_data(
        self,
        orig_config: ImageCaptionSDDataLoaderConfig,
        ui_data: dict[gr.components.Component, Any],
    ) -> ImageCaptionSDDataLoaderConfig:
        """Return a copy of `orig_config` updated with the values in `ui_data`.

        Args:
            orig_config: The config to start from. If None, a default config is used: a JSONL dataset with an
                empty path, default aspect ratio buckets, resolution 512, no center crop, random flip enabled,
                no caption prefix and 4 data loading workers.
            ui_data: Mapping from UI component to its current value. Mutated: every entry consumed by this group
                (and by its nested dataset and aspect ratio bucket groups) is popped.

        Returns:
            The updated config. `orig_config` itself is not modified.

        Raises:
            KeyError: If one of this group's components is missing from `ui_data`.
        """
        if orig_config is None:
            # ImageCaptionSDDataLoaderConfig is imported at module level; only these helper configs are not.
            from invoke_training.config.data.data_loader_config import AspectRatioBucketConfig
            from invoke_training.config.data.dataset_config import ImageCaptionJsonlDatasetConfig

            orig_config = ImageCaptionSDDataLoaderConfig(
                type="IMAGE_CAPTION_SD_DATA_LOADER",
                dataset=ImageCaptionJsonlDatasetConfig(jsonl_path=""),
                aspect_ratio_buckets=AspectRatioBucketConfig(),
                resolution=512,
                center_crop=False,
                random_flip=True,
                caption_prefix=None,
                dataloader_num_workers=4,
            )

        new_config = orig_config.model_copy(deep=True)
        # Give the nested groups the deep-copied sub-configs, so nothing from orig_config can be mutated or shared.
        new_config.dataset = self.dataset.update_config_with_ui_component_data(new_config.dataset, ui_data)
        new_config.aspect_ratio_buckets = self.aspect_ratio_bucket_config_group.update_config_with_ui_component_data(
            new_config.aspect_ratio_buckets, ui_data
        )
        new_config.resolution = ui_data.pop(self.resolution)
        new_config.center_crop = ui_data.pop(self.center_crop)
        new_config.random_flip = ui_data.pop(self.random_flip)
        # An empty textbox means "no prefix".
        new_config.caption_prefix = ui_data.pop(self.caption_prefix) or None
        new_config.dataloader_num_workers = ui_data.pop(self.dataloader_num_workers)
        return new_config
================================================
FILE: src/invoke_training/ui/config_groups/optimizer_config_group.py
================================================
from typing import Any
import gradio as gr
from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
# Union of the optimizer config types the groups below can produce; used in their type annotations.
OptimizerConfig = AdamOptimizerConfig | ProdigyOptimizerConfig
class AdamOptimizerConfigGroup(UIConfigElement):
    """UI controls for the AdamW optimizer, split into a "Core" tab and an "Advanced" tab."""

    def __init__(self):
        # Core: learning rate and the 8-bit toggle, side by side.
        with gr.Tab("Core"), gr.Row():
            self.learning_rate = gr.Number(
                label="Learning Rate",
                info="Initial learning rate to use (after the potential warmup period). Note that in some training "
                "pipelines this can be overriden for a specific group of params.",
                interactive=True,
            )
            self.use_8bit = gr.Checkbox(
                label="Use 8-bit",
                info="Use 8-bit Adam optimizer to reduce VRAM requirements. (Requires bitsandbytes.)",
                interactive=True,
            )
        # Advanced: the betas on one row, weight decay and epsilon on the next.
        with gr.Tab("Advanced"):
            with gr.Row():
                self.beta1 = gr.Number(label="beta1", interactive=True)
                self.beta2 = gr.Number(label="beta2", interactive=True)
            with gr.Row():
                self.weight_decay = gr.Number(label="Weight Decay", interactive=True)
                self.epsilon = gr.Number(label="epsilon", interactive=True)

    def update_ui_components_with_config_data(
        self, config: AdamOptimizerConfig
    ) -> dict[gr.components.Component, Any]:
        """Map every Adam UI component to the corresponding value of `config`."""
        component_values = (
            (self.learning_rate, config.learning_rate),
            (self.beta1, config.beta1),
            (self.beta2, config.beta2),
            (self.weight_decay, config.weight_decay),
            (self.epsilon, config.epsilon),
            (self.use_8bit, config.use_8bit),
        )
        return dict(component_values)

    def update_config_with_ui_component_data(
        self, orig_config: AdamOptimizerConfig | None, ui_data: dict
    ) -> OptimizerConfig:
        """Build a fresh AdamOptimizerConfig from `ui_data`, popping each consumed entry.

        Only `orig_config=None` is supported; the result is built purely from the UI values.
        """
        assert orig_config is None
        field_values = {
            "learning_rate": ui_data.pop(self.learning_rate),
            "beta1": ui_data.pop(self.beta1),
            "beta2": ui_data.pop(self.beta2),
            "weight_decay": ui_data.pop(self.weight_decay),
            "epsilon": ui_data.pop(self.epsilon),
            "use_8bit": ui_data.pop(self.use_8bit),
        }
        return AdamOptimizerConfig(**field_values)
class ProdigyOptimizerConfigGroup(UIConfigElement):
    """UI controls for the Prodigy optimizer, split into a "Core" tab and an "Advanced" tab."""

    def __init__(self):
        # Core: only the (dynamically adjusted) learning rate.
        with gr.Tab("Core"), gr.Row():
            self.learning_rate = gr.Number(
                label="Learning Rate",
                info="The learning rate. For the Prodigy optimizer, the learning rate is adjusted dynamically. A "
                "value of 1.0 is recommended. Note that in some pipelines this can be overriden for specific "
                "groups of parameters.",
                interactive=True,
            )
        # Advanced: weight decay on its own row, then the two boolean switches.
        with gr.Tab("Advanced"):
            with gr.Row():
                self.weight_decay = gr.Number(label="Weight Decay", interactive=True)
            with gr.Row():
                self.use_bias_correction = gr.Checkbox(label="Bias Correction", interactive=True)
                self.safeguard_warmup = gr.Checkbox(label="Safeguard Warmup", interactive=True)

    def update_ui_components_with_config_data(
        self, config: ProdigyOptimizerConfig
    ) -> dict[gr.components.Component, Any]:
        """Map every Prodigy UI component to the corresponding value of `config`."""
        component_values = (
            (self.learning_rate, config.learning_rate),
            (self.weight_decay, config.weight_decay),
            (self.use_bias_correction, config.use_bias_correction),
            (self.safeguard_warmup, config.safeguard_warmup),
        )
        return dict(component_values)

    def update_config_with_ui_component_data(
        self, orig_config: ProdigyOptimizerConfig | None, ui_data: dict
    ) -> OptimizerConfig:
        """Build a fresh ProdigyOptimizerConfig from `ui_data`, popping each consumed entry.

        Only `orig_config=None` is supported; the result is built purely from the UI values.
        """
        assert orig_config is None
        field_values = {
            "learning_rate": ui_data.pop(self.learning_rate),
            "weight_decay": ui_data.pop(self.weight_decay),
            "use_bias_correction": ui_data.pop(self.use_bias_correction),
            "safeguard_warmup": ui_data.pop(self.safeguard_warmup),
        }
        return ProdigyOptimizerConfig(**field_values)
class OptimizerConfigGroup(UIConfigElement):
    """Optimizer type dropdown plus the AdamW and Prodigy sub-groups; only the selected sub-group is visible."""

    def __init__(self):
        with gr.Group():
            self.optimizer_type = gr.Dropdown(label="optimizer", choices=["AdamW", "Prodigy"], interactive=True)
        with gr.Group() as adam_optimizer_config_group:
            self.adam_optimizer_config = AdamOptimizerConfigGroup()
        self.adam_optimizer_config_group = adam_optimizer_config_group
        with gr.Group() as prodigy_optimizer_config_group:
            self.prodigy_optimizer_config = ProdigyOptimizerConfigGroup()
        self.prodigy_optimizer_config_group = prodigy_optimizer_config_group

        # Toggle which sub-group is visible whenever the dropdown changes.
        self.optimizer_type.change(
            self._on_optimizer_type_change,
            inputs=[self.optimizer_type],
            outputs=[self.adam_optimizer_config_group, self.prodigy_optimizer_config_group],
        )

    def _on_optimizer_type_change(self, optimizer_type: str):
        """Return visibility updates that show only the sub-group matching `optimizer_type`."""
        return {
            self.adam_optimizer_config_group: gr.Group(visible=optimizer_type == "AdamW"),
            self.prodigy_optimizer_config_group: gr.Group(visible=optimizer_type == "Prodigy"),
        }

    def update_ui_components_with_config_data(self, config: OptimizerConfig) -> dict[gr.components.Component, Any]:
        """Populate the dropdown and both sub-groups from `config`.

        The sub-group that does not match `config.optimizer_type` is filled with its config's defaults.
        """
        update_dict = {
            self.optimizer_type: config.optimizer_type,
            self.adam_optimizer_config_group: gr.Group(visible=config.optimizer_type == "AdamW"),
            self.prodigy_optimizer_config_group: gr.Group(visible=config.optimizer_type == "Prodigy"),
        }
        update_dict.update(
            self.adam_optimizer_config.update_ui_components_with_config_data(
                config if config.optimizer_type == "AdamW" else AdamOptimizerConfig()
            )
        )
        update_dict.update(
            self.prodigy_optimizer_config.update_ui_components_with_config_data(
                config if config.optimizer_type == "Prodigy" else ProdigyOptimizerConfig()
            )
        )
        return update_dict

    def update_config_with_ui_component_data(self, orig_config: OptimizerConfig, ui_data: dict) -> OptimizerConfig:
        """Build the optimizer config selected in the dropdown from the values in `ui_data`.

        The components of both sub-groups are popped from `ui_data`, so callers can still check that every UI
        value was consumed. Only the selected optimizer's values are validated, however: stale or cleared values
        in the hidden sub-group must not block saving the config.

        Raises:
            ValueError: If the dropdown holds an unknown optimizer type, or if the selected optimizer's values fail
                validation.
        """
        # TODO: Use orig_config?
        optimizer_type = ui_data.pop(self.optimizer_type)
        if optimizer_type == "AdamW":
            selected_group, hidden_group = self.adam_optimizer_config, self.prodigy_optimizer_config
        elif optimizer_type == "Prodigy":
            selected_group, hidden_group = self.prodigy_optimizer_config, self.adam_optimizer_config
        else:
            raise ValueError(f"Invalid optimizer type: {optimizer_type}")

        new_config = selected_group.update_config_with_ui_component_data(None, ui_data)
        try:
            # Still consume the hidden sub-group's entries: its pops happen while its keyword arguments are
            # evaluated, i.e. before the config constructor can raise.
            hidden_group.update_config_with_ui_component_data(None, ui_data)
        except ValueError:
            # Pydantic's ValidationError subclasses ValueError; the hidden optimizer's values are intentionally
            # not validated.
            pass
        return new_config
================================================
FILE: src/invoke_training/ui/config_groups/sd_lora_config_group.py
================================================
import typing
import gradio as gr
from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig
from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup
from invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import (
ImageCaptionSDDataLoaderConfigGroup,
)
from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
from invoke_training.ui.utils.prompts import (
convert_pos_neg_prompts_to_ui_prompts,
convert_ui_prompts_to_pos_neg_prompts,
)
from invoke_training.ui.utils.utils import get_typing_literal_options
class SdLoraConfigGroup(UIConfigElement):
    """UI config group for the SD_LORA training pipeline.

    __init__ creates the Gradio components for the SdLoraConfig fields. The two update_* methods convert between
    an SdLoraConfig and the values of those components.
    """

    def __init__(self):
        """The SD_LORA configs."""
        gr.Markdown("## Basic Configs")
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Tab("Base Model"):
                    self.model = gr.Textbox(
                        label="Model",
                        info="The base model. Can be a Hugging Face Hub model name, or a path to a local model (in "
                        "diffusers or checkpoint format).",
                        type="text",
                        interactive=True,
                    )
                    self.hf_variant = gr.Textbox(
                        label="Variant",
                        info="(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a"
                        " HF Hub model name.",
                        type="text",
                        interactive=True,
                    )
            with gr.Column(scale=3):
                with gr.Tab("Training Outputs"):
                    self.base_pipeline_config_group = BasePipelineConfigGroup()
                    self.max_checkpoints = gr.Number(
                        label="Maximum Number of Checkpoints",
                        info="The maximum number of checkpoints to keep on disk from this training run. Earlier "
                        "checkpoints will be deleted to respect this limit.",
                        interactive=True,
                        precision=0,
                    )

        gr.Markdown("## Data Configs")
        self.image_caption_sd_data_loader_config_group = ImageCaptionSDDataLoaderConfigGroup()

        gr.Markdown("## Optimizer Configs")
        self.optimizer_config_group = OptimizerConfigGroup()

        gr.Markdown("## Speed / Memory Configs")
        with gr.Group():
            with gr.Row():
                self.gradient_accumulation_steps = gr.Number(
                    label="Gradient Accumulation Steps",
                    info="The number of gradient steps to accumulate before each weight update. This is an"
                    " alternative to increasing the batch size when training with limited VRAM. "
                    "effective_batch_size = train_batch_size * gradient_accumulation_steps.",
                    precision=0,
                    interactive=True,
                )
            with gr.Row():
                self.weight_dtype = gr.Dropdown(
                    label="Weight Type",
                    info="The precision of the model weights. Lower precision can speed up training and reduce memory, "
                    "with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases "
                    "if your GPU supports it.",
                    choices=get_typing_literal_options(SdLoraConfig, "weight_dtype"),
                    interactive=True,
                )
            with gr.Row():
                self.cache_text_encoder_outputs = gr.Checkbox(
                    label="Cache Text Encoder Outputs",
                    info="Cache the text encoder outputs to increase speed. This should not be used when training the "
                    "text encoder or performing data augmentations that would change the text encoder outputs.",
                    interactive=True,
                )
                self.cache_vae_outputs = gr.Checkbox(
                    label="Cache VAE Outputs",
                    info="Cache the VAE outputs to increase speed. This should not be used when training the UNet or "
                    "performing data augmentations that would change the VAE outputs.",
                    interactive=True,
                )
            with gr.Row():
                self.enable_cpu_offload_during_validation = gr.Checkbox(
                    label="Enable CPU Offload during Validation",
                    info="Offload models to the CPU sequentially during validation. This reduces peak VRAM "
                    "requirements at the cost of slower validation during training.",
                    interactive=True,
                )
                self.gradient_checkpointing = gr.Checkbox(
                    label="Gradient Checkpointing",
                    info="If True, VRAM requirements are reduced at the cost of ~20% slower training",
                    interactive=True,
                )

        gr.Markdown("## General Training Configs")
        with gr.Tab("Core"):
            with gr.Row():
                self.train_unet = gr.Checkbox(label="Train UNet", interactive=True)
                self.train_text_encoder = gr.Checkbox(label="Train Text Encoder", interactive=True)
            with gr.Row():
                self.unet_learning_rate = gr.Number(
                    label="UNet Learning Rate",
                    info="The UNet learning rate. Set to 0 or leave empty to inherit from the base optimizer "
                    "learning rate.",
                    interactive=True,
                )
                self.text_encoder_learning_rate = gr.Number(
                    label="Text Encoder Learning Rate",
                    info="The text encoder learning rate. Set to 0 or leave empty to inherit from the base optimizer "
                    "learning rate.",
                    interactive=True,
                )
            with gr.Row():
                self.lr_scheduler = gr.Dropdown(
                    label="Learning Rate Scheduler",
                    choices=get_typing_literal_options(SdLoraConfig, "lr_scheduler"),
                    interactive=True,
                )
                # precision=0: the warmup length is a step count, so the UI should hand back an int.
                self.lr_warmup_steps = gr.Number(
                    label="Warmup Steps",
                    info="The number of warmup steps in the "
                    "learning rate schedule, if applicable to the selected scheduler.",
                    precision=0,
                    interactive=True,
                )
            with gr.Row():
                self.use_masks = gr.Checkbox(
                    label="Use Masks", info="This can only be enabled if the dataset contains masks.", interactive=True
                )
        with gr.Tab("Advanced"):
            with gr.Column():
                self.lora_rank_dim = gr.Number(
                    label="LoRA Rank Dim",
                    info="The rank dimension to use for the LoRA layers. Increasing the rank dimension"
                    " increases the model's expressivity, but also increases the size of the generated LoRA model.",
                    interactive=True,
                    precision=0,
                )
                self.min_snr_gamma = gr.Number(
                    label="Minimum SNR Gamma",
                    info="min_snr_gamma acts like an upper bound on the weight of samples with low noise "
                    "levels. If None, then Min-SNR weighting will not be applied. If enabled, the recommended "
                    "value is min_snr_gamma = 5.0.",
                    interactive=True,
                )
                self.max_grad_norm = gr.Number(
                    label="Max Gradient Norm",
                    info="Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).",
                    interactive=True,
                )
                self.train_batch_size = gr.Number(
                    label="Batch Size",
                    info="The Training Batch Size - Higher values require increasing amounts of VRAM.",
                    precision=0,
                    interactive=True,
                )

        gr.Markdown("## Validation")
        with gr.Group():
            self.validation_prompts = gr.Textbox(
                label="Validation Prompts",
                info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' "
                "delimiter. For example: `positive prompt[NEG]negative prompt`. ",
                lines=5,
                interactive=True,
            )
            self.num_validation_images_per_prompt = gr.Number(
                label="# of Validation Images to Generate per Prompt", precision=0, interactive=True
            )

    def update_ui_components_with_config_data(self, config: SdLoraConfig) -> dict[gr.components.Component, typing.Any]:
        """Return a mapping from every UI component of this group (including nested groups) to its value in `config`.

        Raises:
            AssertionError: If the mapping does not cover exactly the components reported by
                get_ui_output_components().
        """
        update_dict = {
            self.model: config.model,
            self.hf_variant: config.hf_variant,
            self.max_checkpoints: config.max_checkpoints,
            self.train_unet: config.train_unet,
            self.unet_learning_rate: config.unet_learning_rate,
            self.train_text_encoder: config.train_text_encoder,
            self.text_encoder_learning_rate: config.text_encoder_learning_rate,
            self.lr_scheduler: config.lr_scheduler,
            self.lr_warmup_steps: config.lr_warmup_steps,
            self.use_masks: config.use_masks,
            self.max_grad_norm: config.max_grad_norm,
            self.train_batch_size: config.train_batch_size,
            self.cache_text_encoder_outputs: config.cache_text_encoder_outputs,
            self.cache_vae_outputs: config.cache_vae_outputs,
            self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation,
            self.gradient_accumulation_steps: config.gradient_accumulation_steps,
            self.weight_dtype: config.weight_dtype,
            self.gradient_checkpointing: config.gradient_checkpointing,
            self.lora_rank_dim: config.lora_rank_dim,
            self.min_snr_gamma: config.min_snr_gamma,
            # Positive and negative prompts are merged into the single textbox format.
            self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(
                config.validation_prompts, config.negative_validation_prompts
            ),
            self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,
        }
        update_dict.update(
            self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader)
        )
        update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config))
        update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer))

        # Sanity check to catch if we accidentally forget to update a UI component.
        assert set(update_dict.keys()) == set(self.get_ui_output_components())

        return update_dict

    def update_config_with_ui_component_data(
        self, orig_config: SdLoraConfig, ui_data: dict[gr.components.Component, typing.Any]
    ) -> SdLoraConfig:
        """Return a copy of `orig_config` updated with the values in `ui_data`.

        Args:
            orig_config: The config to start from; it is not modified.
            ui_data: Mapping from UI component to its current value. Mutated: every consumed entry is popped, and
                it is expected to be empty afterwards.

        Raises:
            KeyError: If a component of this group is missing from `ui_data`.
            AssertionError: If `ui_data` still holds entries after all groups have consumed theirs.
        """
        new_config = orig_config.model_copy(deep=True)

        new_config.model = ui_data.pop(self.model)
        # An empty variant textbox means "no variant".
        new_config.hf_variant = ui_data.pop(self.hf_variant) or None
        new_config.max_checkpoints = ui_data.pop(self.max_checkpoints)
        new_config.train_unet = ui_data.pop(self.train_unet)
        # Learning rates and the max gradient norm use 0 to mean "unset" (None), as the UI help text states.
        unet_lr_value = ui_data.pop(self.unet_learning_rate)
        new_config.unet_learning_rate = None if unet_lr_value == 0 else unet_lr_value
        new_config.train_text_encoder = ui_data.pop(self.train_text_encoder)
        text_encoder_lr_value = ui_data.pop(self.text_encoder_learning_rate)
        new_config.text_encoder_learning_rate = None if text_encoder_lr_value == 0 else text_encoder_lr_value
        new_config.lr_scheduler = ui_data.pop(self.lr_scheduler)
        new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps)
        new_config.use_masks = ui_data.pop(self.use_masks)
        max_grad_norm_value = ui_data.pop(self.max_grad_norm)
        new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value
        new_config.train_batch_size = ui_data.pop(self.train_batch_size)
        new_config.cache_text_encoder_outputs = ui_data.pop(self.cache_text_encoder_outputs)
        new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs)
        new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation)
        new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps)
        new_config.weight_dtype = ui_data.pop(self.weight_dtype)
        new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing)
        new_config.lora_rank_dim = ui_data.pop(self.lora_rank_dim)
        new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma)
        new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt)

        positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts))
        new_config.validation_prompts = positive_prompts
        new_config.negative_validation_prompts = negative_prompts

        new_config.data_loader = self.image_caption_sd_data_loader_config_group.update_config_with_ui_component_data(
            new_config.data_loader, ui_data
        )
        new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data)
        new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(
            new_config.optimizer, ui_data
        )

        # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred
        # to the config.
        assert len(ui_data) == 0

        return new_config
================================================
FILE: src/invoke_training/ui/config_groups/sd_textual_inversion_config_group.py
================================================
import typing
import gradio as gr
from invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig
from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup
from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup
from invoke_training.ui.config_groups.textual_inversion_sd_data_loader_config_group import (
TextualInversionSDDataLoaderConfigGroup,
)
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
from invoke_training.ui.utils.prompts import (
convert_pos_neg_prompts_to_ui_prompts,
convert_ui_prompts_to_pos_neg_prompts,
)
from invoke_training.ui.utils.utils import get_typing_literal_options
class SdTextualInversionConfigGroup(UIConfigElement):
def __init__(self):
    """The SD_TEXTUAL_INVERSION configs.

    Creates the Gradio components for the SdTextualInversionConfig fields, in display order, inside their layout
    containers.
    """
    gr.Markdown("## Basic Configs")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Tab("Base Model"):
                self.model = gr.Textbox(
                    label="Model",
                    info="The base model. Can be a Hugging Face Hub model name, or a path to a local model (in "
                    "diffusers or checkpoint format).",
                    type="text",
                    interactive=True,
                )
                self.hf_variant = gr.Textbox(
                    label="Variant",
                    info="(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a"
                    " HF Hub model name.",
                    type="text",
                    interactive=True,
                )
        with gr.Column(scale=3):
            with gr.Tab("Training Outputs"):
                self.base_pipeline_config_group = BasePipelineConfigGroup()
                self.max_checkpoints = gr.Number(
                    label="Maximum Number of Checkpoints",
                    info="The maximum number of checkpoints to keep on disk from this training run. Earlier "
                    "checkpoints will be deleted to respect this limit.",
                    interactive=True,
                    precision=0,
                )

    gr.Markdown("## Data Configs")
    self.textual_inversion_sd_data_loader_config_group = TextualInversionSDDataLoaderConfigGroup()

    gr.Markdown("## Textual Inversion Configs")
    self.num_vectors = gr.Number(
        label="Num Vectors",
        info="The number of TI vectors that will be trained. Can be overridden by 'Initial Phrase'.",
        interactive=True,
        precision=0,
    )
    self.placeholder_token = gr.Textbox(
        label="Placeholder Token",
        info="The special word to associate the learned embeddings with. Choose a unique token that is unlikely to "
        "already exist in the tokenizer's vocabulary.",
        interactive=True,
    )
    self.initializer_token = gr.Textbox(
        label="Initializer Token",
        info="Only one of 'Initializer Token' or 'Initial Phrase' should be set. A vocabulary token to use as an "
        "initializer for the placeholder token. It should be a single word that roughly describes the object or "
        "style that you're trying to train on. The initializer token must map to a single tokenizer token.",
        interactive=True,
    )
    self.initial_phrase = gr.Textbox(
        label="Initial Phrase",
        info="Only one of 'Initializer Token' or 'Initial Phrase' should be set. A phrase that will be used to "
        "initialize the placeholder token embedding. The phrase will be tokenized, and the corresponding "
        "embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be "
        "inferred from the length of the tokenized phrase, so keep the phrase short.",
        interactive=True,
    )

    gr.Markdown("## Optimizer Configs")
    self.optimizer_config_group = OptimizerConfigGroup()

    gr.Markdown("## Speed / Memory Configs")
    with gr.Group():
        with gr.Row():
            self.gradient_accumulation_steps = gr.Number(
                label="Gradient Accumulation Steps",
                info="The number of gradient steps to accumulate before each weight update. This is an"
                " alternative to increasing the batch size when training with limited VRAM. "
                "effective_batch_size = train_batch_size * gradient_accumulation_steps.",
                precision=0,
                interactive=True,
            )
        with gr.Row():
            self.weight_dtype = gr.Dropdown(
                label="Weight Type",
                info="The precision of the model weights. Lower precision can speed up training and reduce memory, "
                "with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases "
                "if your GPU supports it.",
                choices=get_typing_literal_options(SdTextualInversionConfig, "weight_dtype"),
                interactive=True,
            )
        with gr.Row():
            self.cache_vae_outputs = gr.Checkbox(
                label="Cache VAE Outputs",
                info="Cache the VAE outputs to increase speed. This should not be used when training the UNet or "
                "performing data augmentations that would change the VAE outputs.",
                interactive=True,
            )
        with gr.Row():
            self.enable_cpu_offload_during_validation = gr.Checkbox(
                label="Enable CPU Offload during Validation",
                info="Offload models to the CPU sequentially during validation. This reduces peak VRAM "
                "requirements at the cost of slower validation during training.",
                interactive=True,
            )
            self.gradient_checkpointing = gr.Checkbox(
                label="Gradient Checkpointing",
                info="If True, VRAM requirements are reduced at the cost of ~20% slower training",
                interactive=True,
            )

    gr.Markdown("## General Training Configs")
    with gr.Tab("Core"):
        with gr.Row():
            self.lr_scheduler = gr.Dropdown(
                label="Learning Rate Scheduler",
                choices=get_typing_literal_options(SdTextualInversionConfig, "lr_scheduler"),
                interactive=True,
            )
            # precision=0: the warmup length is a step count, so the UI should hand back an int.
            self.lr_warmup_steps = gr.Number(
                label="Warmup Steps",
                info="The number of warmup steps in the "
                "learning rate schedule, if applicable to the selected scheduler.",
                precision=0,
                interactive=True,
            )
        with gr.Row():
            self.use_masks = gr.Checkbox(
                label="Use Masks", info="This can only be enabled if the dataset contains masks.", interactive=True
            )
    with gr.Tab("Advanced"):
        with gr.Column():
            self.min_snr_gamma = gr.Number(
                label="Minimum SNR Gamma",
                info="min_snr_gamma acts like an upper bound on the weight of samples with low noise "
                "levels. If None, then Min-SNR weighting will not be applied. If enabled, the recommended "
                "value is min_snr_gamma = 5.0.",
                interactive=True,
            )
            self.max_grad_norm = gr.Number(
                label="Max Gradient Norm",
                info="Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).",
                interactive=True,
            )
            self.train_batch_size = gr.Number(
                label="Batch Size",
                info="The Training Batch Size - Higher values require increasing amounts of VRAM.",
                precision=0,
                interactive=True,
            )

    gr.Markdown("## Validation")
    with gr.Group():
        self.validation_prompts = gr.Textbox(
            label="Validation Prompts",
            info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' "
            "delimiter. For example: `positive prompt[NEG]negative prompt`. ",
            lines=5,
            interactive=True,
        )
        self.num_validation_images_per_prompt = gr.Number(
            label="# of Validation Images to Generate per Prompt", precision=0, interactive=True
        )
def update_ui_components_with_config_data(
    self, config: SdTextualInversionConfig
) -> dict[gr.components.Component, typing.Any]:
    """Build the ``{component: value}`` mapping that displays ``config`` in the UI.

    The mapping covers this group's own components plus every component owned by the
    nested data-loader, base-pipeline and optimizer sub-groups.

    Args:
        config: The config whose values should be shown.

    Returns:
        A dict from each UI component to the value it should display.

    Raises:
        AssertionError: If the mapping's keys differ from ``self.get_ui_output_components()``.
    """
    # Validation prompts are stored as two parallel lists in the config but edited as a
    # single text box in the UI.
    ui_prompt_text = convert_pos_neg_prompts_to_ui_prompts(
        config.validation_prompts, config.negative_validation_prompts
    )

    component_values: dict[gr.components.Component, typing.Any] = {
        self.model: config.model,
        self.hf_variant: config.hf_variant,
        self.num_vectors: config.num_vectors,
        self.placeholder_token: config.placeholder_token,
        self.initializer_token: config.initializer_token,
        self.initial_phrase: config.initial_phrase,
        self.max_checkpoints: config.max_checkpoints,
        self.lr_scheduler: config.lr_scheduler,
        self.lr_warmup_steps: config.lr_warmup_steps,
        self.use_masks: config.use_masks,
        self.max_grad_norm: config.max_grad_norm,
        self.train_batch_size: config.train_batch_size,
        self.cache_vae_outputs: config.cache_vae_outputs,
        self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation,
        self.gradient_accumulation_steps: config.gradient_accumulation_steps,
        self.weight_dtype: config.weight_dtype,
        self.gradient_checkpointing: config.gradient_checkpointing,
        self.min_snr_gamma: config.min_snr_gamma,
        self.validation_prompts: ui_prompt_text,
        self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,
    }

    # Merge in the values produced by the nested config groups.
    nested_updates = (
        self.textual_inversion_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader),
        self.base_pipeline_config_group.update_ui_components_with_config_data(config),
        self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer),
    )
    for partial_values in nested_updates:
        component_values.update(partial_values)

    # Guard against a newly added UI component that was never wired up here.
    assert set(component_values) == set(self.get_ui_output_components())
    return component_values
def update_config_with_ui_component_data(
    self, orig_config: SdTextualInversionConfig, ui_data: dict[gr.components.Component, typing.Any]
) -> SdTextualInversionConfig:
    """Return a deep copy of ``orig_config`` with the values entered in the UI applied.

    ``ui_data`` is consumed: each entry is popped as it is transferred, so an empty dict at
    the end proves every UI value found its way into the config.

    Args:
        orig_config: The config to start from. It is not modified.
        ui_data: Mapping from UI component to its current value. Emptied by this call.

    Returns:
        The updated config.

    Raises:
        AssertionError: If any entry of ``ui_data`` was not consumed.
    """
    take = ui_data.pop

    config = orig_config.model_copy(deep=True)
    config.model = take(self.model)
    # Empty text boxes mean "unset" for these optional string fields.
    config.hf_variant = take(self.hf_variant) or None
    config.num_vectors = take(self.num_vectors)
    config.placeholder_token = take(self.placeholder_token)
    config.initializer_token = take(self.initializer_token) or None
    config.initial_phrase = take(self.initial_phrase) or None
    config.max_checkpoints = take(self.max_checkpoints)
    config.lr_scheduler = take(self.lr_scheduler)
    config.lr_warmup_steps = take(self.lr_warmup_steps)
    config.use_masks = take(self.use_masks)

    # A max gradient norm of 0 is the UI's way of saying "no clipping".
    grad_norm = take(self.max_grad_norm)
    config.max_grad_norm = grad_norm if grad_norm != 0 else None

    config.train_batch_size = take(self.train_batch_size)
    config.cache_vae_outputs = take(self.cache_vae_outputs)
    config.enable_cpu_offload_during_validation = take(self.enable_cpu_offload_during_validation)
    config.gradient_accumulation_steps = take(self.gradient_accumulation_steps)
    config.weight_dtype = take(self.weight_dtype)
    config.gradient_checkpointing = take(self.gradient_checkpointing)
    config.min_snr_gamma = take(self.min_snr_gamma)
    config.num_validation_images_per_prompt = take(self.num_validation_images_per_prompt)

    # The single prompt text box is split back into parallel positive/negative lists.
    config.validation_prompts, config.negative_validation_prompts = convert_ui_prompts_to_pos_neg_prompts(
        take(self.validation_prompts)
    )

    # Nested groups consume their own entries from ui_data.
    config.data_loader = self.textual_inversion_sd_data_loader_config_group.update_config_with_ui_component_data(
        config.data_loader, ui_data
    )
    config = self.base_pipeline_config_group.update_config_with_ui_component_data(config, ui_data)
    config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(config.optimizer, ui_data)

    # Every UI value must have been transferred to the config.
    assert not ui_data
    return config
================================================
FILE: src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py
================================================
import typing
import gradio as gr
from invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig
from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup
from invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import (
ImageCaptionSDDataLoaderConfigGroup,
)
from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
from invoke_training.ui.utils.prompts import (
convert_pos_neg_prompts_to_ui_prompts,
convert_ui_prompts_to_pos_neg_prompts,
)
from invoke_training.ui.utils.utils import get_typing_literal_options
class SdxlFinetuneConfigGroup(UIConfigElement):
    """Gradio UI for editing an ``SdxlFinetuneConfig``.

    Instantiating the class lays out its components inside the currently active Gradio
    layout context. The two ``update_*`` methods convert between a config object and a
    ``{component: value}`` mapping.
    """

    def __init__(self):
        """Create the SDXL_FINETUNE config components in the current Gradio layout context."""
        gr.Markdown("## Basic Configs")
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Tab("Base Model"):
                    self.model = gr.Textbox(
                        label="Model",
                        info="The base model. Can be a Hugging Face Hub model name, or a path to a local model (in "
                        "diffusers or checkpoint format).",
                        type="text",
                        interactive=True,
                    )
                    self.hf_variant = gr.Textbox(
                        label="Variant",
                        info="(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a"
                        " HF Hub model name.",
                        type="text",
                        interactive=True,
                    )
                    self.vae_model = gr.Textbox(
                        label="VAE Model",
                        info="(optional) If set, this overrides the base model's default VAE model.",
                        type="text",
                        interactive=True,
                    )
            with gr.Column(scale=3):
                with gr.Tab("Training Outputs"):
                    self.base_pipeline_config_group = BasePipelineConfigGroup()
                    self.save_checkpoint_format = gr.Dropdown(
                        label="Checkpoint Format",
                        info="The save format for the checkpoints. `full_diffusers` saves the full model in diffusers "
                        "format. `trained_only_diffusers` saves only the parts of the model that were finetuned "
                        "(i.e. the UNet).",
                        choices=get_typing_literal_options(SdxlFinetuneConfig, "save_checkpoint_format"),
                        interactive=True,
                    )
                    self.save_dtype = gr.Dropdown(
                        label="Save Dtype",
                        info="The dtype to use when saving the model.",
                        choices=get_typing_literal_options(SdxlFinetuneConfig, "save_dtype"),
                        interactive=True,
                    )
                    self.max_checkpoints = gr.Number(
                        label="Maximum Number of Checkpoints",
                        info="The maximum number of checkpoints to keep on disk from this training run. Earlier "
                        "checkpoints will be deleted to respect this limit.",
                        interactive=True,
                        precision=0,
                    )

        gr.Markdown("## Data Configs")
        self.image_caption_sd_data_loader_config_group = ImageCaptionSDDataLoaderConfigGroup()

        gr.Markdown("## Optimizer Configs")
        self.optimizer_config_group = OptimizerConfigGroup()

        gr.Markdown("## Speed / Memory Configs")
        with gr.Group():
            with gr.Row():
                # Adjacent string literals are concatenated, so each piece must end with its own space.
                self.gradient_accumulation_steps = gr.Number(
                    label="Gradient Accumulation Steps",
                    info="The number of gradient steps to accumulate before each weight update. This is an alternative "
                    "to increasing the batch size when training with limited VRAM. "
                    "effective_batch_size = train_batch_size * gradient_accumulation_steps.",
                    precision=0,
                    interactive=True,
                )
            with gr.Row():
                self.weight_dtype = gr.Dropdown(
                    label="Weight Type",
                    info="The precision of the model weights. Lower precision can speed up training and reduce memory, "
                    "with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases "
                    "if your GPU supports it.",
                    choices=get_typing_literal_options(SdxlFinetuneConfig, "weight_dtype"),
                    interactive=True,
                )
            with gr.Row():
                self.cache_text_encoder_outputs = gr.Checkbox(
                    label="Cache Text Encoder Outputs",
                    info="Cache the text encoder outputs to increase speed. This should not be used when training the "
                    "text encoder or performing data augmentations that would change the text encoder outputs.",
                    interactive=True,
                )
                # VAE latents depend only on the input images, so training the UNet does not invalidate the cache;
                # only image augmentations do.
                self.cache_vae_outputs = gr.Checkbox(
                    label="Cache VAE Outputs",
                    info="Cache the VAE outputs to increase speed. This should not be used when performing data "
                    "augmentations that would change the VAE outputs.",
                    interactive=True,
                )
            with gr.Row():
                self.enable_cpu_offload_during_validation = gr.Checkbox(
                    label="Enable CPU Offload during Validation",
                    info="Offload models to the CPU sequentially during validation. This reduces peak VRAM "
                    "requirements at the cost of slower validation during training.",
                    interactive=True,
                )
                self.gradient_checkpointing = gr.Checkbox(
                    label="Gradient Checkpointing",
                    info="If True, VRAM requirements are reduced at the cost of ~20% slower training",
                    interactive=True,
                )

        gr.Markdown("## General Training Configs")
        with gr.Tab("Core"):
            with gr.Row():
                self.lr_scheduler = gr.Dropdown(
                    label="Learning Rate Scheduler",
                    choices=get_typing_literal_options(SdxlFinetuneConfig, "lr_scheduler"),
                    interactive=True,
                )
                self.lr_warmup_steps = gr.Number(
                    label="Warmup Steps",
                    info="The number of warmup steps in the "
                    "learning rate schedule, if applicable to the selected scheduler.",
                    interactive=True,
                )
            with gr.Row():
                self.use_masks = gr.Checkbox(
                    label="Use Masks", info="This can only be enabled if the dataset contains masks.", interactive=True
                )
        with gr.Tab("Advanced"):
            with gr.Row():
                self.min_snr_gamma = gr.Number(
                    label="Minimum SNR Gamma",
                    info="min_snr_gamma acts like an upper bound on the weight of samples with low noise "
                    "levels. If None, then Min-SNR weighting will not be applied. If enabled, the recommended "
                    "value is min_snr_gamma = 5.0.",
                    interactive=True,
                )
                self.max_grad_norm = gr.Number(
                    label="Max Gradient Norm",
                    info="Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).",
                    interactive=True,
                )
                self.train_batch_size = gr.Number(
                    label="Batch Size",
                    info="The Training Batch Size - Higher values require increasing amounts of VRAM.",
                    precision=0,
                    interactive=True,
                )

        gr.Markdown("## Validation")
        with gr.Group():
            self.validation_prompts = gr.Textbox(
                label="Validation Prompts",
                info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' "
                "delimiter. For example: `positive prompt[NEG]negative prompt`. ",
                lines=5,
                interactive=True,
            )
            self.num_validation_images_per_prompt = gr.Number(
                label="# of Validation Images to Generate per Prompt", precision=0, interactive=True
            )

    def update_ui_components_with_config_data(
        self, config: SdxlFinetuneConfig
    ) -> dict[gr.components.Component, typing.Any]:
        """Build the ``{component: value}`` mapping that displays ``config`` in the UI.

        Includes the components of the nested data-loader, base-pipeline and optimizer groups.

        Raises:
            AssertionError: If the mapping's keys differ from ``self.get_ui_output_components()``.
        """
        update_dict = {
            self.model: config.model,
            self.hf_variant: config.hf_variant,
            self.vae_model: config.vae_model,
            self.save_checkpoint_format: config.save_checkpoint_format,
            self.save_dtype: config.save_dtype,
            self.max_checkpoints: config.max_checkpoints,
            self.lr_scheduler: config.lr_scheduler,
            self.lr_warmup_steps: config.lr_warmup_steps,
            self.use_masks: config.use_masks,
            self.min_snr_gamma: config.min_snr_gamma,
            self.max_grad_norm: config.max_grad_norm,
            self.train_batch_size: config.train_batch_size,
            self.cache_text_encoder_outputs: config.cache_text_encoder_outputs,
            self.cache_vae_outputs: config.cache_vae_outputs,
            self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation,
            self.gradient_accumulation_steps: config.gradient_accumulation_steps,
            self.weight_dtype: config.weight_dtype,
            self.gradient_checkpointing: config.gradient_checkpointing,
            # The config stores positive/negative prompts as parallel lists; the UI edits them as one text box.
            self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(
                config.validation_prompts, config.negative_validation_prompts
            ),
            self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,
        }
        update_dict.update(
            self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader)
        )
        update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config))
        update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer))

        # Sanity check to catch if we accidentally forget to update a UI component.
        assert set(update_dict.keys()) == set(self.get_ui_output_components())

        return update_dict

    def update_config_with_ui_component_data(
        self, orig_config: SdxlFinetuneConfig, ui_data: dict[gr.components.Component, typing.Any]
    ) -> SdxlFinetuneConfig:
        """Return a deep copy of ``orig_config`` with the UI values in ``ui_data`` applied.

        ``ui_data`` is consumed (entries are popped as they are used).

        Raises:
            AssertionError: If any entry of ``ui_data`` was not transferred to the config.
        """
        new_config = orig_config.model_copy(deep=True)

        new_config.model = ui_data.pop(self.model)
        # Empty text boxes mean "unset" for these optional fields.
        new_config.hf_variant = ui_data.pop(self.hf_variant) or None
        new_config.vae_model = ui_data.pop(self.vae_model) or None
        new_config.save_checkpoint_format = ui_data.pop(self.save_checkpoint_format)
        new_config.save_dtype = ui_data.pop(self.save_dtype)
        new_config.max_checkpoints = ui_data.pop(self.max_checkpoints)
        new_config.lr_scheduler = ui_data.pop(self.lr_scheduler)
        new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps)
        new_config.use_masks = ui_data.pop(self.use_masks)
        new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma)
        # 0 in the UI means "no gradient clipping".
        max_grad_norm_value = ui_data.pop(self.max_grad_norm)
        new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value
        new_config.train_batch_size = ui_data.pop(self.train_batch_size)
        new_config.cache_text_encoder_outputs = ui_data.pop(self.cache_text_encoder_outputs)
        new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs)
        new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation)
        new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps)
        new_config.weight_dtype = ui_data.pop(self.weight_dtype)
        new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing)
        new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt)

        positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts))
        new_config.validation_prompts = positive_prompts
        new_config.negative_validation_prompts = negative_prompts

        new_config.data_loader = self.image_caption_sd_data_loader_config_group.update_config_with_ui_component_data(
            new_config.data_loader, ui_data
        )
        new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data)
        new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(
            new_config.optimizer, ui_data
        )

        # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred
        # to the config.
        assert len(ui_data) == 0

        return new_config
================================================
FILE: src/invoke_training/ui/config_groups/sdxl_lora_and_textual_inversion_config_group.py
================================================
import typing
import gradio as gr
from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import (
SdxlLoraAndTextualInversionConfig,
)
from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup
from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup
from invoke_training.ui.config_groups.textual_inversion_sd_data_loader_config_group import (
TextualInversionSDDataLoaderConfigGroup,
)
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
from invoke_training.ui.utils.prompts import (
convert_pos_neg_prompts_to_ui_prompts,
convert_ui_prompts_to_pos_neg_prompts,
)
from invoke_training.ui.utils.utils import get_typing_literal_options
class SdxlLoraAndTextualInversionConfigGroup(UIConfigElement):
    """Gradio UI for editing an ``SdxlLoraAndTextualInversionConfig``.

    Instantiating the class lays out its components inside the currently active Gradio
    layout context. The two ``update_*`` methods convert between a config object and a
    ``{component: value}`` mapping.
    """

    def __init__(self):
        """Create the SDXL_LORA_AND_TEXTUAL_INVERSION config components in the current Gradio layout context."""
        gr.Markdown("## Basic Configs")
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Tab("Base Model"):
                    self.model = gr.Textbox(
                        label="Model",
                        info="The base model. Can be a Hugging Face Hub model name, or a path to a local model (in "
                        "diffusers or checkpoint format).",
                        type="text",
                        interactive=True,
                    )
                    self.hf_variant = gr.Textbox(
                        label="Variant",
                        info="(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a"
                        " HF Hub model name.",
                        type="text",
                        interactive=True,
                    )
                    self.vae_model = gr.Textbox(
                        label="VAE Model",
                        info="(optional) If set, this overrides the base model's default VAE model.",
                        type="text",
                        interactive=True,
                    )
            with gr.Column(scale=3):
                with gr.Tab("Training Outputs"):
                    self.base_pipeline_config_group = BasePipelineConfigGroup()
                    self.max_checkpoints = gr.Number(
                        label="Maximum Number of Checkpoints",
                        info="The maximum number of checkpoints to keep on disk from this training run. Earlier "
                        "checkpoints will be deleted to respect this limit.",
                        interactive=True,
                        precision=0,
                    )

        gr.Markdown("## Data Configs")
        # NOTE: Despite its name, this attribute holds a textual-inversion data loader group. The name is kept
        # unchanged in case code outside this class references it.
        self.image_caption_sd_data_loader_config_group = TextualInversionSDDataLoaderConfigGroup()

        gr.Markdown("## Textual Inversion Configs")
        self.num_vectors = gr.Number(
            label="Num Vectors",
            info="The number of TI vectors that will be trained. Can be overridden by 'Initial Phrase'.",
            interactive=True,
            precision=0,
        )
        self.placeholder_token = gr.Textbox(
            label="Placeholder Token",
            info="The special word to associate the learned embeddings with. Choose a unique token that is unlikely to "
            "already exist in the tokenizer's vocabulary.",
            interactive=True,
        )
        self.initializer_token = gr.Textbox(
            label="Initializer Token",
            info="Only one of 'Initializer Token' or 'Initial Phrase' should be set. A vocabulary token to use as an "
            "initializer for the placeholder token. It should be a single word that roughly describes the object or "
            "style that you're trying to train on. The initializer token must map to a single tokenizer token.",
            interactive=True,
        )
        self.initial_phrase = gr.Textbox(
            label="Initial Phrase",
            info="Only one of 'Initializer Token' or 'Initial Phrase' should be set. A phrase that will be used to "
            "initialize the placeholder token embedding. The phrase will be tokenized, and the corresponding "
            "embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be "
            "inferred from the length of the tokenized phrase, so keep the phrase short.",
            interactive=True,
        )

        gr.Markdown("## Optimizer Configs")
        self.optimizer_config_group = OptimizerConfigGroup()

        gr.Markdown("## Speed / Memory Configs")
        with gr.Group():
            with gr.Row():
                # Adjacent string literals are concatenated, so each piece must end with its own space.
                self.gradient_accumulation_steps = gr.Number(
                    label="Gradient Accumulation Steps",
                    info="The number of gradient steps to accumulate before each weight update. This is an alternative "
                    "to increasing the batch size when training with limited VRAM. "
                    "effective_batch_size = train_batch_size * gradient_accumulation_steps.",
                    precision=0,
                    interactive=True,
                )
            with gr.Row():
                self.weight_dtype = gr.Dropdown(
                    label="Weight Type",
                    info="The precision of the model weights. Lower precision can speed up training and reduce memory, "
                    "with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases "
                    "if your GPU supports it.",
                    choices=get_typing_literal_options(SdxlLoraAndTextualInversionConfig, "weight_dtype"),
                    interactive=True,
                )
            with gr.Row():
                # Training textual inversion embeddings changes the text encoder outputs, so it is listed here too.
                self.cache_text_encoder_outputs = gr.Checkbox(
                    label="Cache Text Encoder Outputs",
                    info="Cache the text encoder outputs to increase speed. This should not be used when training the "
                    "text encoder or textual inversion embeddings, or when performing data augmentations that would "
                    "change the text encoder outputs.",
                    interactive=True,
                )
                # VAE latents depend only on the input images, so training the UNet does not invalidate the cache;
                # only image augmentations do.
                self.cache_vae_outputs = gr.Checkbox(
                    label="Cache VAE Outputs",
                    info="Cache the VAE outputs to increase speed. This should not be used when performing data "
                    "augmentations that would change the VAE outputs.",
                    interactive=True,
                )
            with gr.Row():
                self.enable_cpu_offload_during_validation = gr.Checkbox(
                    label="Enable CPU Offload during Validation",
                    info="Offload models to the CPU sequentially during validation. This reduces peak VRAM "
                    "requirements at the cost of slower validation during training.",
                    interactive=True,
                )
                self.gradient_checkpointing = gr.Checkbox(
                    label="Gradient Checkpointing",
                    info="If True, VRAM requirements are reduced at the cost of ~20% slower training",
                    interactive=True,
                )

        gr.Markdown("## General Training Configs")
        with gr.Tab("Core"):
            with gr.Row():
                self.train_unet = gr.Checkbox(label="Train UNet", interactive=True)
                self.train_text_encoder = gr.Checkbox(label="Train Text Encoder", interactive=True)
                self.train_ti = gr.Checkbox(label="Train Textual Inversion Token", scale=2, interactive=True)
            with gr.Row():
                self.unet_learning_rate = gr.Number(
                    label="UNet Learning Rate",
                    info="The UNet learning rate. Set to 0 or leave empty to inherit from the base optimizer "
                    "learning rate.",
                    interactive=True,
                )
                self.text_encoder_learning_rate = gr.Number(
                    label="Text Encoder Learning Rate",
                    info="The text encoder learning rate. Set to 0 or leave empty to inherit from the base optimizer "
                    "learning rate.",
                    interactive=True,
                )
                self.textual_inversion_learning_rate = gr.Number(
                    label="Textual Inversion Learning Rate",
                    info="The textual inversion learning rate. Set to 0 or leave empty to inherit from the base "
                    "optimizer learning rate.",
                    interactive=True,
                )
                self.ti_train_steps_ratio = gr.Number(label="Textual Inversion Train Steps Ratio", interactive=True)
            with gr.Row():
                self.lr_scheduler = gr.Dropdown(
                    label="Learning Rate Scheduler",
                    choices=get_typing_literal_options(SdxlLoraAndTextualInversionConfig, "lr_scheduler"),
                    interactive=True,
                )
                self.lr_warmup_steps = gr.Number(
                    label="Warmup Steps",
                    info="The number of warmup steps in the "
                    "learning rate schedule, if applicable to the selected scheduler.",
                    interactive=True,
                )
            with gr.Row():
                self.use_masks = gr.Checkbox(
                    label="Use Masks", info="This can only be enabled if the dataset contains masks.", interactive=True
                )
        with gr.Tab("Advanced"):
            with gr.Column():
                self.lora_rank_dim = gr.Number(
                    label="LoRA Rank Dim",
                    info="The rank dimension to use for the LoRA layers. Increasing the rank dimension increases"
                    " the model's expressivity, but also increases the size of the generated LoRA model.",
                    interactive=True,
                    precision=0,
                )
                self.min_snr_gamma = gr.Number(
                    label="Minimum SNR Gamma",
                    info="min_snr_gamma acts like an upper bound on the weight of samples with low noise "
                    "levels. If None, then Min-SNR weighting will not be applied. If enabled, the recommended "
                    "value is min_snr_gamma = 5.0.",
                    interactive=True,
                )
                self.max_grad_norm = gr.Number(
                    label="Max Gradient Norm",
                    info="Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).",
                    interactive=True,
                )
                self.train_batch_size = gr.Number(
                    label="Batch Size",
                    info="The Training Batch Size - Higher values require increasing amounts of VRAM.",
                    precision=0,
                    interactive=True,
                )

        gr.Markdown("## Validation")
        with gr.Group():
            self.validation_prompts = gr.Textbox(
                label="Validation Prompts",
                info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' "
                "delimiter. For example: `positive prompt[NEG]negative prompt`. ",
                lines=5,
                interactive=True,
            )
            self.num_validation_images_per_prompt = gr.Number(
                label="# of Validation Images to Generate per Prompt", precision=0, interactive=True
            )

    def update_ui_components_with_config_data(
        self, config: SdxlLoraAndTextualInversionConfig
    ) -> dict[gr.components.Component, typing.Any]:
        """Build the ``{component: value}`` mapping that displays ``config`` in the UI.

        Includes the components of the nested data-loader, base-pipeline and optimizer groups.

        Raises:
            AssertionError: If the mapping's keys differ from ``self.get_ui_output_components()``.
        """
        update_dict = {
            self.model: config.model,
            self.hf_variant: config.hf_variant,
            self.vae_model: config.vae_model,
            self.num_vectors: config.num_vectors,
            self.placeholder_token: config.placeholder_token,
            self.initializer_token: config.initializer_token,
            self.initial_phrase: config.initial_phrase,
            self.max_checkpoints: config.max_checkpoints,
            self.train_unet: config.train_unet,
            self.train_text_encoder: config.train_text_encoder,
            self.train_ti: config.train_ti,
            self.unet_learning_rate: config.unet_learning_rate,
            self.text_encoder_learning_rate: config.text_encoder_learning_rate,
            self.textual_inversion_learning_rate: config.textual_inversion_learning_rate,
            self.ti_train_steps_ratio: config.ti_train_steps_ratio,
            self.lr_scheduler: config.lr_scheduler,
            self.lr_warmup_steps: config.lr_warmup_steps,
            self.use_masks: config.use_masks,
            self.max_grad_norm: config.max_grad_norm,
            self.train_batch_size: config.train_batch_size,
            self.cache_text_encoder_outputs: config.cache_text_encoder_outputs,
            self.cache_vae_outputs: config.cache_vae_outputs,
            self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation,
            self.gradient_accumulation_steps: config.gradient_accumulation_steps,
            self.weight_dtype: config.weight_dtype,
            self.gradient_checkpointing: config.gradient_checkpointing,
            self.lora_rank_dim: config.lora_rank_dim,
            self.min_snr_gamma: config.min_snr_gamma,
            # The config stores positive/negative prompts as parallel lists; the UI edits them as one text box.
            self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(
                config.validation_prompts, config.negative_validation_prompts
            ),
            self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,
        }
        update_dict.update(
            self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader)
        )
        update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config))
        update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer))

        # Sanity check to catch if we accidentally forget to update a UI component.
        assert set(update_dict.keys()) == set(self.get_ui_output_components())

        return update_dict

    def update_config_with_ui_component_data(
        self, orig_config: SdxlLoraAndTextualInversionConfig, ui_data: dict[gr.components.Component, typing.Any]
    ) -> SdxlLoraAndTextualInversionConfig:
        """Return a deep copy of ``orig_config`` with the UI values in ``ui_data`` applied.

        ``ui_data`` is consumed (entries are popped as they are used).

        Raises:
            AssertionError: If any entry of ``ui_data`` was not transferred to the config.
        """
        new_config = orig_config.model_copy(deep=True)

        new_config.model = ui_data.pop(self.model)
        # Empty text boxes mean "unset" for these optional fields.
        new_config.hf_variant = ui_data.pop(self.hf_variant) or None
        new_config.vae_model = ui_data.pop(self.vae_model) or None
        new_config.num_vectors = ui_data.pop(self.num_vectors)
        new_config.placeholder_token = ui_data.pop(self.placeholder_token)
        new_config.initializer_token = ui_data.pop(self.initializer_token) or None
        new_config.initial_phrase = ui_data.pop(self.initial_phrase) or None
        new_config.max_checkpoints = ui_data.pop(self.max_checkpoints)
        new_config.train_unet = ui_data.pop(self.train_unet)
        new_config.train_text_encoder = ui_data.pop(self.train_text_encoder)
        new_config.train_ti = ui_data.pop(self.train_ti)

        # A per-component learning rate of 0 in the UI means "inherit from the base optimizer learning rate".
        unet_lr_value = ui_data.pop(self.unet_learning_rate)
        new_config.unet_learning_rate = None if unet_lr_value == 0 else unet_lr_value
        text_encoder_lr_value = ui_data.pop(self.text_encoder_learning_rate)
        new_config.text_encoder_learning_rate = None if text_encoder_lr_value == 0 else text_encoder_lr_value
        ti_lr_value = ui_data.pop(self.textual_inversion_learning_rate)
        new_config.textual_inversion_learning_rate = None if ti_lr_value == 0 else ti_lr_value

        new_config.ti_train_steps_ratio = ui_data.pop(self.ti_train_steps_ratio)
        new_config.lr_scheduler = ui_data.pop(self.lr_scheduler)
        new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps)
        new_config.use_masks = ui_data.pop(self.use_masks)
        # 0 in the UI means "no gradient clipping".
        max_grad_norm_value = ui_data.pop(self.max_grad_norm)
        new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value
        new_config.train_batch_size = ui_data.pop(self.train_batch_size)
        new_config.cache_text_encoder_outputs = ui_data.pop(self.cache_text_encoder_outputs)
        new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs)
        new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation)
        new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps)
        new_config.weight_dtype = ui_data.pop(self.weight_dtype)
        new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing)
        new_config.lora_rank_dim = ui_data.pop(self.lora_rank_dim)
        new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma)
        new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt)

        positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts))
        new_config.validation_prompts = positive_prompts
        new_config.negative_validation_prompts = negative_prompts

        new_config.data_loader = self.image_caption_sd_data_loader_config_group.update_config_with_ui_component_data(
            new_config.data_loader, ui_data
        )
        new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data)
        new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(
            new_config.optimizer, ui_data
        )

        # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred
        # to the config.
        assert len(ui_data) == 0

        return new_config
================================================
FILE: src/invoke_training/ui/config_groups/sdxl_lora_config_group.py
================================================
import typing
import gradio as gr
from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig
from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup
from invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import (
ImageCaptionSDDataLoaderConfigGroup,
)
from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
from invoke_training.ui.utils.prompts import (
convert_pos_neg_prompts_to_ui_prompts,
convert_ui_prompts_to_pos_neg_prompts,
)
from invoke_training.ui.utils.utils import get_typing_literal_options
class SdxlLoraConfigGroup(UIConfigElement):
    """UI config group for the SDXL LoRA training pipeline (``SdxlLoraConfig``)."""

    def __init__(self):
        """The SDXL_LORA configs.

        Creates one Gradio component per UI-editable ``SdxlLoraConfig`` field, plus nested config groups for the base
        pipeline, data loader and optimizer settings. Components are stored as instance attributes so that
        ``get_ui_input_components()`` / ``get_ui_output_components()`` can discover them.
        """
        gr.Markdown("## Basic Configs")
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Tab("Base Model"):
                    self.model = gr.Textbox(
                        label="Model",
                        info="The base model. Can be a Hugging Face Hub model name, or a path to a local model (in "
                        "diffusers or checkpoint format).",
                        type="text",
                        interactive=True,
                    )
                    self.hf_variant = gr.Textbox(
                        label="Variant",
                        info="(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a"
                        " HF Hub model name.",
                        type="text",
                        interactive=True,
                    )
                    self.vae_model = gr.Textbox(
                        label="VAE Model",
                        info="(optional) If set, this overrides the base model's default VAE model.",
                        type="text",
                        interactive=True,
                    )
            with gr.Column(scale=3):
                with gr.Tab("Training Outputs"):
                    self.base_pipeline_config_group = BasePipelineConfigGroup()
                    self.max_checkpoints = gr.Number(
                        label="Maximum Number of Checkpoints",
                        info="The maximum number of checkpoints to keep on disk from this training run. Earlier "
                        "checkpoints will be deleted to respect this limit.",
                        interactive=True,
                        precision=0,
                    )

        gr.Markdown("## Data Configs")
        self.image_caption_sd_data_loader_config_group = ImageCaptionSDDataLoaderConfigGroup()

        gr.Markdown("## Optimizer Configs")
        self.optimizer_config_group = OptimizerConfigGroup()

        gr.Markdown("## Speed / Memory Configs")
        with gr.Group():
            with gr.Row():
                self.gradient_accumulation_steps = gr.Number(
                    label="Gradient Accumulation Steps",
                    info="The number of gradient steps to accumulate before each weight update. This is an"
                    " alternative to increasing the batch size when training with limited VRAM. "
                    "effective_batch_size = train_batch_size * gradient_accumulation_steps.",
                    precision=0,
                    interactive=True,
                )
            with gr.Row():
                self.weight_dtype = gr.Dropdown(
                    label="Weight Type",
                    info="The precision of the model weights. Lower precision can speed up training and reduce memory, "
                    "with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases "
                    "if your GPU supports it.",
                    choices=get_typing_literal_options(SdxlLoraConfig, "weight_dtype"),
                    interactive=True,
                )
            with gr.Row():
                self.cache_text_encoder_outputs = gr.Checkbox(
                    label="Cache Text Encoder Outputs",
                    info="Cache the text encoder outputs to increase speed. This should not be used when training the "
                    "text encoder or performing data augmentations that would change the text encoder outputs.",
                    interactive=True,
                )
                self.cache_vae_outputs = gr.Checkbox(
                    label="Cache VAE Outputs",
                    info="Cache the VAE outputs to increase speed. This should not be used when training the UNet or "
                    "performing data augmentations that would change the VAE outputs.",
                    interactive=True,
                )
            with gr.Row():
                self.enable_cpu_offload_during_validation = gr.Checkbox(
                    label="Enable CPU Offload during Validation",
                    info="Offload models to the CPU sequentially during validation. This reduces peak VRAM "
                    "requirements at the cost of slower validation during training.",
                    interactive=True,
                )
                self.gradient_checkpointing = gr.Checkbox(
                    label="Gradient Checkpointing",
                    info="If True, VRAM requirements are reduced at the cost of ~20% slower training",
                    interactive=True,
                )

        gr.Markdown("## General Training Configs")
        with gr.Tab("Core"):
            with gr.Row():
                self.train_unet = gr.Checkbox(label="Train UNet", interactive=True)
                self.train_text_encoder = gr.Checkbox(label="Train Text Encoder", interactive=True)
            with gr.Row():
                self.unet_learning_rate = gr.Number(
                    label="UNet Learning Rate",
                    info="The UNet learning rate. Set to 0 or leave empty to inherit from the base optimizer "
                    "learning rate.",
                    interactive=True,
                )
                self.text_encoder_learning_rate = gr.Number(
                    label="Text Encoder Learning Rate",
                    info="The text encoder learning rate. Set to 0 or leave empty to inherit from the base optimizer "
                    "learning rate.",
                    interactive=True,
                )
            with gr.Row():
                self.lr_scheduler = gr.Dropdown(
                    label="Learning Rate Scheduler",
                    choices=get_typing_literal_options(SdxlLoraConfig, "lr_scheduler"),
                    interactive=True,
                )
                # precision=0 so that the UI yields an integer step count, like the other count fields.
                self.lr_warmup_steps = gr.Number(
                    label="Warmup Steps",
                    info="The number of warmup steps in the "
                    "learning rate schedule, if applicable to the selected scheduler.",
                    interactive=True,
                    precision=0,
                )
            with gr.Row():
                self.use_masks = gr.Checkbox(
                    label="Use Masks", info="This can only be enabled if the dataset contains masks.", interactive=True
                )
        with gr.Tab("Advanced"):
            with gr.Column():
                self.lora_rank_dim = gr.Number(
                    label="LoRA Rank Dim",
                    info="The rank dimension to use for the LoRA layers. Increasing the rank dimension increases"
                    " the model's expressivity, but also increases the size of the generated LoRA model.",
                    interactive=True,
                    precision=0,
                )
                self.min_snr_gamma = gr.Number(
                    label="Minimum SNR Gamma",
                    info="min_snr_gamma acts like an upper bound on the weight of samples with low noise "
                    "levels. Set to 0 or leave empty to disable Min-SNR weighting. If enabled, the recommended "
                    "value is min_snr_gamma = 5.0.",
                    interactive=True,
                )
                self.max_grad_norm = gr.Number(
                    label="Max Gradient Norm",
                    info="Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).",
                    interactive=True,
                )
                self.train_batch_size = gr.Number(
                    label="Batch Size",
                    info="The Training Batch Size - Higher values require increasing amounts of VRAM.",
                    precision=0,
                    interactive=True,
                )

        gr.Markdown("## Validation")
        with gr.Group():
            self.validation_prompts = gr.Textbox(
                label="Validation Prompts",
                info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' "
                "delimiter. For example: `positive prompt[NEG]negative prompt`. ",
                lines=5,
                interactive=True,
            )
            self.num_validation_images_per_prompt = gr.Number(
                label="# of Validation Images to Generate per Prompt", precision=0, interactive=True
            )

    def update_ui_components_with_config_data(
        self, config: SdxlLoraConfig
    ) -> dict[gr.components.Component, typing.Any]:
        """Map every UI component owned by this group (including nested groups) to its value from ``config``.

        Raises:
            AssertionError: If the produced mapping does not cover exactly the set of output components (i.e. a
                component was added to the UI but not wired up here).
        """
        update_dict = {
            self.model: config.model,
            self.hf_variant: config.hf_variant,
            self.vae_model: config.vae_model,
            self.max_checkpoints: config.max_checkpoints,
            self.train_unet: config.train_unet,
            self.unet_learning_rate: config.unet_learning_rate,
            self.train_text_encoder: config.train_text_encoder,
            self.text_encoder_learning_rate: config.text_encoder_learning_rate,
            self.lr_scheduler: config.lr_scheduler,
            self.lr_warmup_steps: config.lr_warmup_steps,
            self.use_masks: config.use_masks,
            self.max_grad_norm: config.max_grad_norm,
            self.train_batch_size: config.train_batch_size,
            self.cache_text_encoder_outputs: config.cache_text_encoder_outputs,
            self.cache_vae_outputs: config.cache_vae_outputs,
            self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation,
            self.gradient_accumulation_steps: config.gradient_accumulation_steps,
            self.weight_dtype: config.weight_dtype,
            self.gradient_checkpointing: config.gradient_checkpointing,
            self.lora_rank_dim: config.lora_rank_dim,
            self.min_snr_gamma: config.min_snr_gamma,
            self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(
                config.validation_prompts, config.negative_validation_prompts
            ),
            self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,
        }
        update_dict.update(
            self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader)
        )
        update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config))
        update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer))

        # Sanity check to catch if we accidentally forget to update a UI component.
        assert set(update_dict.keys()) == set(self.get_ui_output_components())

        return update_dict

    def update_config_with_ui_component_data(
        self, orig_config: SdxlLoraConfig, ui_data: dict[gr.components.Component, typing.Any]
    ) -> SdxlLoraConfig:
        """Return a deep copy of ``orig_config`` updated with the values in ``ui_data``.

        Entries are popped from ``ui_data`` as they are consumed (``ui_data`` is mutated).

        Raises:
            AssertionError: If any entry of ``ui_data`` was not consumed.
        """
        new_config = orig_config.model_copy(deep=True)

        new_config.model = ui_data.pop(self.model)
        new_config.hf_variant = ui_data.pop(self.hf_variant) or None
        new_config.vae_model = ui_data.pop(self.vae_model) or None
        new_config.max_checkpoints = ui_data.pop(self.max_checkpoints)
        new_config.train_unet = ui_data.pop(self.train_unet)

        # A learning rate of 0 means "inherit from the base optimizer learning rate" (stored as None).
        unet_lr_value = ui_data.pop(self.unet_learning_rate)
        new_config.unet_learning_rate = None if unet_lr_value == 0 else unet_lr_value
        new_config.train_text_encoder = ui_data.pop(self.train_text_encoder)
        text_encoder_lr_value = ui_data.pop(self.text_encoder_learning_rate)
        new_config.text_encoder_learning_rate = None if text_encoder_lr_value == 0 else text_encoder_lr_value

        new_config.lr_scheduler = ui_data.pop(self.lr_scheduler)
        new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps)
        new_config.use_masks = ui_data.pop(self.use_masks)

        # A max gradient norm of 0 means "no clipping" (stored as None).
        max_grad_norm_value = ui_data.pop(self.max_grad_norm)
        new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value

        new_config.train_batch_size = ui_data.pop(self.train_batch_size)
        new_config.cache_text_encoder_outputs = ui_data.pop(self.cache_text_encoder_outputs)
        new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs)
        new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation)
        new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps)
        new_config.weight_dtype = ui_data.pop(self.weight_dtype)
        new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing)
        new_config.lora_rank_dim = ui_data.pop(self.lora_rank_dim)

        # Treat 0 like an empty field: Min-SNR weighting disabled (None). With the standard Min-SNR weighting
        # min(snr, gamma) / snr, a gamma of 0 would zero out every sample's loss weight.
        min_snr_gamma_value = ui_data.pop(self.min_snr_gamma)
        new_config.min_snr_gamma = None if min_snr_gamma_value == 0 else min_snr_gamma_value

        new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt)
        positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts))
        new_config.validation_prompts = positive_prompts
        new_config.negative_validation_prompts = negative_prompts

        new_config.data_loader = self.image_caption_sd_data_loader_config_group.update_config_with_ui_component_data(
            new_config.data_loader, ui_data
        )
        new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data)
        new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(
            new_config.optimizer, ui_data
        )

        # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred
        # to the config.
        assert len(ui_data) == 0

        return new_config
================================================
FILE: src/invoke_training/ui/config_groups/sdxl_textual_inversion_config_group.py
================================================
import typing
import gradio as gr
from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig
from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup
from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup
from invoke_training.ui.config_groups.textual_inversion_sd_data_loader_config_group import (
TextualInversionSDDataLoaderConfigGroup,
)
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
from invoke_training.ui.utils.prompts import (
convert_pos_neg_prompts_to_ui_prompts,
convert_ui_prompts_to_pos_neg_prompts,
)
from invoke_training.ui.utils.utils import get_typing_literal_options
class SdxlTextualInversionConfigGroup(UIConfigElement):
    """UI config group for the SDXL textual inversion training pipeline (``SdxlTextualInversionConfig``)."""

    def __init__(self):
        """The SDXL_TEXTUAL_INVERSION configs.

        Creates one Gradio component per UI-editable ``SdxlTextualInversionConfig`` field, plus nested config groups
        for the base pipeline, data loader and optimizer settings. Components are stored as instance attributes so
        that ``get_ui_input_components()`` / ``get_ui_output_components()`` can discover them.
        """
        gr.Markdown("## Basic Configs")
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Tab("Base Model"):
                    self.model = gr.Textbox(
                        label="Model",
                        info="The base model. Can be a Hugging Face Hub model name, or a path to a local model (in "
                        "diffusers or checkpoint format).",
                        type="text",
                        interactive=True,
                    )
                    self.hf_variant = gr.Textbox(
                        label="Variant",
                        info="(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a"
                        " HF Hub model name.",
                        type="text",
                        interactive=True,
                    )
                    self.vae_model = gr.Textbox(
                        label="VAE Model",
                        info="(optional) If set, this overrides the base model's default VAE model.",
                        type="text",
                        interactive=True,
                    )
            with gr.Column(scale=3):
                with gr.Tab("Training Outputs"):
                    self.base_pipeline_config_group = BasePipelineConfigGroup()
                    self.max_checkpoints = gr.Number(
                        label="Maximum Number of Checkpoints",
                        info="The maximum number of checkpoints to keep on disk from this training run. Earlier "
                        "checkpoints will be deleted to respect this limit.",
                        interactive=True,
                        precision=0,
                    )

        gr.Markdown("## Data Configs")
        self.textual_inversion_sd_data_loader_config_group = TextualInversionSDDataLoaderConfigGroup()

        gr.Markdown("## Textual Inversion Configs")
        self.num_vectors = gr.Number(
            label="Num Vectors",
            info="The number of TI vectors that will be trained. Can be overridden by 'Initial Phrase'.",
            interactive=True,
            precision=0,
        )
        self.placeholder_token = gr.Textbox(
            label="Placeholder Token",
            info="The special word to associate the learned embeddings with. Choose a unique token that is unlikely to "
            "already exist in the tokenizer's vocabulary.",
            interactive=True,
        )
        self.initializer_token = gr.Textbox(
            label="Initializer Token",
            info="Only one of 'Initializer Token' or 'Initial Phrase' should be set. A vocabulary token to use as an "
            "initializer for the placeholder token. It should be a single word that roughly describes the object or "
            "style that you're trying to train on. The initializer token must map to a single tokenizer token.",
            interactive=True,
        )
        self.initial_phrase = gr.Textbox(
            label="Initial Phrase",
            info="Only one of 'Initializer Token' or 'Initial Phrase' should be set. A phrase that will be used to "
            "initialize the placeholder token embedding. The phrase will be tokenized, and the corresponding "
            "embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be "
            "inferred from the length of the tokenized phrase, so keep the phrase short.",
            interactive=True,
        )

        gr.Markdown("## Optimizer Configs")
        self.optimizer_config_group = OptimizerConfigGroup()

        gr.Markdown("## Speed / Memory Configs")
        with gr.Group():
            with gr.Row():
                self.gradient_accumulation_steps = gr.Number(
                    label="Gradient Accumulation Steps",
                    info="The number of gradient steps to accumulate before each weight update. This is an"
                    " alternative to increasing the batch size when training with limited VRAM. "
                    "effective_batch_size = train_batch_size * gradient_accumulation_steps.",
                    precision=0,
                    interactive=True,
                )
            with gr.Row():
                self.weight_dtype = gr.Dropdown(
                    label="Weight Type",
                    info="The precision of the model weights. Lower precision can speed up training and reduce memory, "
                    "with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases "
                    "if your GPU supports it.",
                    choices=get_typing_literal_options(SdxlTextualInversionConfig, "weight_dtype"),
                    interactive=True,
                )
            with gr.Row():
                self.cache_vae_outputs = gr.Checkbox(
                    label="Cache VAE Outputs",
                    info="Cache the VAE outputs to increase speed. This should not be used when training the UNet or "
                    "performing data augmentations that would change the VAE outputs.",
                    interactive=True,
                )
            with gr.Row():
                self.enable_cpu_offload_during_validation = gr.Checkbox(
                    label="Enable CPU Offload during Validation",
                    info="Offload models to the CPU sequentially during validation. This reduces peak VRAM "
                    "requirements at the cost of slower validation during training.",
                    interactive=True,
                )
                self.gradient_checkpointing = gr.Checkbox(
                    label="Gradient Checkpointing",
                    info="If True, VRAM requirements are reduced at the cost of ~20% slower training",
                    interactive=True,
                )

        gr.Markdown("## General Training Configs")
        with gr.Tab("Core"):
            with gr.Row():
                self.lr_scheduler = gr.Dropdown(
                    label="Learning Rate Scheduler",
                    choices=get_typing_literal_options(SdxlTextualInversionConfig, "lr_scheduler"),
                    interactive=True,
                )
                # precision=0 so that the UI yields an integer step count, like the other count fields.
                self.lr_warmup_steps = gr.Number(
                    label="Warmup Steps",
                    info="The number of warmup steps in the "
                    "learning rate schedule, if applicable to the selected scheduler.",
                    interactive=True,
                    precision=0,
                )
            with gr.Row():
                self.use_masks = gr.Checkbox(
                    label="Use Masks", info="This can only be enabled if the dataset contains masks.", interactive=True
                )
        with gr.Tab("Advanced"):
            with gr.Column():
                self.min_snr_gamma = gr.Number(
                    label="Minimum SNR Gamma",
                    info="min_snr_gamma acts like an upper bound on the weight of samples with low noise "
                    "levels. Set to 0 or leave empty to disable Min-SNR weighting. If enabled, the recommended "
                    "value is min_snr_gamma = 5.0.",
                    interactive=True,
                )
                self.max_grad_norm = gr.Number(
                    label="Max Gradient Norm",
                    info="Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).",
                    interactive=True,
                )
                self.train_batch_size = gr.Number(
                    label="Batch Size",
                    info="The Training Batch Size - Higher values require increasing amounts of VRAM.",
                    precision=0,
                    interactive=True,
                )

        gr.Markdown("## Validation")
        with gr.Group():
            self.validation_prompts = gr.Textbox(
                label="Validation Prompts",
                info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' "
                "delimiter. For example: `positive prompt[NEG]negative prompt`. ",
                lines=5,
                interactive=True,
            )
            self.num_validation_images_per_prompt = gr.Number(
                label="# of Validation Images to Generate per Prompt", precision=0, interactive=True
            )

    def update_ui_components_with_config_data(
        self, config: SdxlTextualInversionConfig
    ) -> dict[gr.components.Component, typing.Any]:
        """Map every UI component owned by this group (including nested groups) to its value from ``config``.

        Raises:
            AssertionError: If the produced mapping does not cover exactly the set of output components.
        """
        update_dict = {
            self.model: config.model,
            self.hf_variant: config.hf_variant,
            self.vae_model: config.vae_model,
            self.num_vectors: config.num_vectors,
            self.placeholder_token: config.placeholder_token,
            self.initializer_token: config.initializer_token,
            self.initial_phrase: config.initial_phrase,
            self.max_checkpoints: config.max_checkpoints,
            self.lr_scheduler: config.lr_scheduler,
            self.lr_warmup_steps: config.lr_warmup_steps,
            self.use_masks: config.use_masks,
            self.max_grad_norm: config.max_grad_norm,
            self.train_batch_size: config.train_batch_size,
            self.cache_vae_outputs: config.cache_vae_outputs,
            self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation,
            self.gradient_accumulation_steps: config.gradient_accumulation_steps,
            self.weight_dtype: config.weight_dtype,
            self.gradient_checkpointing: config.gradient_checkpointing,
            self.min_snr_gamma: config.min_snr_gamma,
            self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts(
                config.validation_prompts, config.negative_validation_prompts
            ),
            self.num_validation_images_per_prompt: config.num_validation_images_per_prompt,
        }
        update_dict.update(
            self.textual_inversion_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader)
        )
        update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config))
        update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer))

        # Sanity check to catch if we accidentally forget to update a UI component.
        assert set(update_dict.keys()) == set(self.get_ui_output_components())

        return update_dict

    def update_config_with_ui_component_data(
        self, orig_config: SdxlTextualInversionConfig, ui_data: dict[gr.components.Component, typing.Any]
    ) -> SdxlTextualInversionConfig:
        """Return a deep copy of ``orig_config`` updated with the values in ``ui_data``.

        Entries are popped from ``ui_data`` as they are consumed (``ui_data`` is mutated).

        Raises:
            AssertionError: If any entry of ``ui_data`` was not consumed.
        """
        new_config = orig_config.model_copy(deep=True)

        new_config.model = ui_data.pop(self.model)
        new_config.hf_variant = ui_data.pop(self.hf_variant) or None
        new_config.vae_model = ui_data.pop(self.vae_model) or None
        new_config.num_vectors = ui_data.pop(self.num_vectors)
        new_config.placeholder_token = ui_data.pop(self.placeholder_token)
        # Empty textboxes mean "not set".
        new_config.initializer_token = ui_data.pop(self.initializer_token) or None
        new_config.initial_phrase = ui_data.pop(self.initial_phrase) or None
        new_config.max_checkpoints = ui_data.pop(self.max_checkpoints)
        new_config.lr_scheduler = ui_data.pop(self.lr_scheduler)
        new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps)
        new_config.use_masks = ui_data.pop(self.use_masks)

        # A max gradient norm of 0 means "no clipping" (stored as None).
        max_grad_norm_value = ui_data.pop(self.max_grad_norm)
        new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value

        new_config.train_batch_size = ui_data.pop(self.train_batch_size)
        new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs)
        new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation)
        new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps)
        new_config.weight_dtype = ui_data.pop(self.weight_dtype)
        new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing)

        # Treat 0 like an empty field: Min-SNR weighting disabled (None). With the standard Min-SNR weighting
        # min(snr, gamma) / snr, a gamma of 0 would zero out every sample's loss weight.
        min_snr_gamma_value = ui_data.pop(self.min_snr_gamma)
        new_config.min_snr_gamma = None if min_snr_gamma_value == 0 else min_snr_gamma_value

        new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt)
        positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts))
        new_config.validation_prompts = positive_prompts
        new_config.negative_validation_prompts = negative_prompts

        new_config.data_loader = (
            self.textual_inversion_sd_data_loader_config_group.update_config_with_ui_component_data(
                new_config.data_loader, ui_data
            )
        )
        new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data)
        new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data(
            new_config.optimizer, ui_data
        )

        # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred
        # to the config.
        assert len(ui_data) == 0

        return new_config
================================================
FILE: src/invoke_training/ui/config_groups/textual_inversion_sd_data_loader_config_group.py
================================================
from typing import Any
import gradio as gr
from invoke_training.config.data.data_loader_config import (
TextualInversionSDDataLoaderConfig,
)
from invoke_training.ui.config_groups.aspect_ratio_bucket_config_group import AspectRatioBucketConfigGroup
from invoke_training.ui.config_groups.dataset_config_group import DatasetConfigGroup
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
class TextualInversionSDDataLoaderConfigGroup(UIConfigElement):
    """UI config group for ``TextualInversionSDDataLoaderConfig`` (dataset, captioning and image loading options)."""

    def __init__(self):
        """Create the data source, data loading and aspect ratio bucketing components."""
        with gr.Row():
            with gr.Column(scale=1):
                with gr.Tab("Data Source Configs"):
                    with gr.Group():
                        self.dataset = DatasetConfigGroup(
                            allowed_types=[
                                "HF_HUB_IMAGE_CAPTION_DATASET",
                                "IMAGE_CAPTION_JSONL_DATASET",
                                "IMAGE_CAPTION_DIR_DATASET",
                                "IMAGE_DIR_DATASET",
                            ]
                        )
            with gr.Column(scale=3):
                with gr.Tab("Data Loading Configs"):
                    with gr.Group():
                        self.caption_preset = gr.Dropdown(
                            label="Caption Preset",
                            choices=["None", "style", "object"],
                            info="Only one of 'Caption Preset' or 'Caption Templates' should be set.\nSelect a Caption "
                            "Preset option to use a set of pre-configured templates.",
                            interactive=True,
                        )
                        self.caption_templates = gr.Textbox(
                            label="Caption Templates",
                            info="Only one of 'Caption Preset' or 'Caption Templates' should be set. Enter one template"
                            " per line. Each template should contain a single placeholder token slot indicated by '{}',"
                            " for example 'a photo of a {}'.",
                            lines=5,
                            interactive=True,
                        )
                    with gr.Row():
                        self.keep_original_captions = gr.Checkbox(
                            label="Keep Original Captions",
                            info="If True, the caption templates will be prepended to the original captions."
                            " If False, the caption templates will replace the original captions.",
                            interactive=True,
                        )
                        self.shuffle_caption_delimiter = gr.Textbox(
                            label="Shuffle Caption Delimiter",
                            info="If set, captions will be split on the provided delimiter (e.g. ',') and shuffled.",
                            interactive=True,
                        )
                    with gr.Row():
                        self.resolution = gr.Number(
                            label="Resolution",
                            info="The resolution for input images. All of the images in the dataset will be"
                            " resized to this resolution unless the aspect_ratio_buckets config is set.",
                            precision=0,
                            interactive=True,
                        )
                        self.dataloader_num_workers = gr.Number(
                            label="Dataloading Workers",
                            info="Number of subprocesses to use for data loading. 0 means that the data will"
                            " be loaded in the main process.",
                            precision=0,
                            interactive=True,
                        )
                    with gr.Row():
                        self.center_crop = gr.Checkbox(
                            label="Center Crop",
                            info="If set, input images will be center-cropped to the target resolution. Otherwise,"
                            " input images will be randomly cropped to the target resolution.",
                            interactive=True,
                        )
                        self.random_flip = gr.Checkbox(
                            label="Random Flip",
                            info="If set, random flip augmentations will be applied to input images.",
                            interactive=True,
                        )
                with gr.Tab("Aspect Ratio Bucketing Configs"):
                    self.aspect_ratio_bucket_config_group = AspectRatioBucketConfigGroup()

    def update_ui_components_with_config_data(
        self, config: TextualInversionSDDataLoaderConfig
    ) -> dict[gr.components.Component, Any]:
        """Map every UI component owned by this group (including nested groups) to its value from ``config``."""
        # Special handling of caption_preset to translate None to "None" (the dropdown's "no preset" option).
        caption_preset = "None"
        if config.caption_preset is not None:
            caption_preset = config.caption_preset

        update_dict = {
            self.caption_preset: caption_preset,
            # One template per line in the textbox.
            self.caption_templates: "\n".join(config.caption_templates or []),
            self.keep_original_captions: config.keep_original_captions,
            self.shuffle_caption_delimiter: config.shuffle_caption_delimiter,
            self.resolution: config.resolution,
            self.center_crop: config.center_crop,
            self.random_flip: config.random_flip,
            self.dataloader_num_workers: config.dataloader_num_workers,
        }
        update_dict.update(self.dataset.update_ui_components_with_config_data(config.dataset))
        update_dict.update(
            self.aspect_ratio_bucket_config_group.update_ui_components_with_config_data(config.aspect_ratio_buckets)
        )
        return update_dict

    def update_config_with_ui_component_data(
        self, orig_config: TextualInversionSDDataLoaderConfig, ui_data: dict[gr.components.Component, Any]
    ) -> TextualInversionSDDataLoaderConfig:
        """Return a deep copy of ``orig_config`` updated with the values in ``ui_data``.

        Entries for this group's components are popped from ``ui_data`` (``ui_data`` is mutated).

        Raises:
            KeyError: If the caption preset value is not one of the known dropdown options.
        """
        new_config = orig_config.model_copy(deep=True)

        # Translate the "None" dropdown option back to a real None. A cleared dropdown (value None) is treated the
        # same way instead of raising KeyError.
        caption_presets = {None: None, "None": None, "style": "style", "object": "object"}
        caption_preset = caption_presets[ui_data.pop(self.caption_preset)]

        # One template per line; blank lines are dropped and an empty result means "no templates" (None). The textbox
        # value may be None (e.g. when the components were cleared), so fall back to "" before splitting.
        raw_caption_templates = ui_data.pop(self.caption_templates) or ""
        caption_templates = [
            line.strip() for line in raw_caption_templates.split("\n") if line.strip() != ""
        ] or None

        new_config.dataset = self.dataset.update_config_with_ui_component_data(orig_config.dataset, ui_data)
        new_config.aspect_ratio_buckets = self.aspect_ratio_bucket_config_group.update_config_with_ui_component_data(
            orig_config.aspect_ratio_buckets, ui_data
        )
        new_config.caption_preset = caption_preset
        new_config.caption_templates = caption_templates
        new_config.keep_original_captions = ui_data.pop(self.keep_original_captions)
        # An empty delimiter textbox means "no caption shuffling".
        new_config.shuffle_caption_delimiter = ui_data.pop(self.shuffle_caption_delimiter) or None
        new_config.resolution = ui_data.pop(self.resolution)
        new_config.center_crop = ui_data.pop(self.center_crop)
        new_config.random_flip = ui_data.pop(self.random_flip)
        new_config.dataloader_num_workers = ui_data.pop(self.dataloader_num_workers)
        return new_config
================================================
FILE: src/invoke_training/ui/config_groups/ui_config_element.py
================================================
from typing import Any
import gradio as gr
class UIConfigElement:
    """A base class for UI blocks that represent a part of a config.

    Subclasses store their Gradio components, and any nested ``UIConfigElement`` groups, as instance attributes. The
    component collectors below discover them by walking ``vars(self)`` and recursing into nested elements.
    """

    def get_ui_output_components(self) -> list[gr.components.Component]:
        """Collect, recursively, every component that can be the target of a UI update.

        Besides ``gr.components.Component`` instances, ``gr.Group`` attributes are collected here as well.
        """
        found: list[gr.components.Component] = []
        for member in vars(self).values():
            if isinstance(member, (gr.components.Component, gr.Group)):
                found.append(member)
            elif isinstance(member, UIConfigElement):
                found += member.get_ui_output_components()
        return found

    def get_ui_input_components(self) -> list[gr.components.Component]:
        """Collect, recursively, every component whose value should be read as input."""
        found: list[gr.components.Component] = []
        for member in vars(self).values():
            if isinstance(member, gr.components.Component):
                found.append(member)
            elif isinstance(member, UIConfigElement):
                found += member.get_ui_input_components()
        return found

    def update_ui_components_with_config_data(self, config) -> dict[gr.components.Component, Any]:
        """Build a mapping from each UI component to its new value taken from ``config``.

        Subclasses must override this.
        """
        raise NotImplementedError()

    def update_config_with_ui_component_data(self, orig_config, ui_data: dict[gr.components.Component, Any]):
        """Apply the UI component values in ``ui_data`` to ``orig_config`` and return the resulting config.

        Subclasses must override this.
        """
        raise NotImplementedError()
================================================
FILE: src/invoke_training/ui/gradio_blocks/header.py
================================================
import gradio as gr
from invoke_training.ui.utils.utils import get_assets_dir_path
class Header:
    """Page header: the Invoke Training logo followed by navigation and documentation links."""

    def __init__(self):
        """Create the logo image and the links markdown, in that order."""
        gr.Image(
            value=get_assets_dir_path() / "logo.png",
            label="Invoke Training App",
            width=200,
            interactive=False,
            container=False,
        )
        links_markdown = (
            "[Home](/)\n\n"
            "*Invoke Training* - [Documentation](https://invoke-ai.github.io/invoke-training/) --"
            " Learn more about Invoke at [invoke.com](https://www.invoke.com/)"
        )
        gr.Markdown(links_markdown)
================================================
FILE: src/invoke_training/ui/gradio_blocks/pipeline_tab.py
================================================
import typing
import gradio as gr
import yaml
from invoke_training.config.pipeline_config import PipelineConfig
from invoke_training.ui.config_groups.ui_config_element import UIConfigElement
from invoke_training.ui.utils.utils import load_config_from_yaml
class PipelineTab:
def __init__(
self,
name: str,
default_config_file_path: str,
pipeline_config_cls: typing.Type[PipelineConfig],
config_group_cls: typing.Type[UIConfigElement],
run_training_cb: typing.Callable[[PipelineConfig], None],
app: gr.Blocks,
):
"""A tab for a single training pipeline type.
Args:
run_training_cb (typing.Callable[[PipelineConfig], None]): A callback function to run the training process.
"""
self._name = name
self._default_config_file_path = default_config_file_path
self._pipeline_config_cls = pipeline_config_cls
self._run_training_cb = run_training_cb
# self._default_config is the config that was last loaded from the reference config file.
self._default_config = None
# self._current_config is the config that was most recently generated from the UI.
self._current_config = None
gr.Markdown(f"# {self._name} Training Config")
self.reference_config_file = gr.Textbox(
label="Reference Config File Path", value=default_config_file_path, interactive=True
)
reset_config_button = gr.Button(value="Reload reference config")
self.pipeline_config_group = config_group_cls()
gr.Markdown("## Config Output")
generate_config_button = gr.Button(value="Generate Config")
self._config_yaml = gr.Code(label="Config YAML", language="yaml", interactive=False)
gr.Markdown(
"""# Run Training
'Start Training' starts the training process in the background. Check the terminal for logs.
**Warning: Click 'Generate Config' to capture all of the latest changes before starting training.**
"""
)
run_training_button = gr.Button(value="Start Training")
gr.Markdown(
"""# Visualize Results
Once you've started training, you can see the results by launching tensorboard with the following
command:
```bash
tensorboard --logdir /path/to/output_dir
```
Alternatively, you can browse the output directory directly to find model checkpoints, logs, and validation
images.
"""
)
reset_config_button.click(
self.on_reset_config_button_click,
inputs=self.reference_config_file,
outputs=self.pipeline_config_group.get_ui_output_components() + [self._config_yaml],
)
generate_config_button.click(
self.on_generate_config_button_click,
inputs=set(self.pipeline_config_group.get_ui_input_components()),
outputs=self.pipeline_config_group.get_ui_output_components() + [self._config_yaml],
)
run_training_button.click(self.on_run_training_button_click, inputs=[], outputs=[])
# On app load, reset the configs based on the default reference config file.
# We'll wrap this in a try-except block to handle any errors during loading
def safe_load_config(file_path):
try:
return self.on_reset_config_button_click(file_path)
except Exception as e:
print(f"Error during app.load for {self._name}: {e}")
# Return empty values for all outputs to avoid UI errors
output_components = self.pipeline_config_group.get_ui_output_components() + [self._config_yaml]
return {comp: None for comp in output_components}
app.load(
safe_load_config,
inputs=self.reference_config_file,
outputs=self.pipeline_config_group.get_ui_output_components() + [self._config_yaml],
)
def on_reset_config_button_click(self, file_path: str):
    """Reload the reference config from ``file_path`` and push its values into the UI.

    Args:
        file_path: Path of the reference config YAML file to load.

    Returns:
        A Gradio update dict mapping output components to new values. On success, the
        config-group components are refreshed from the loaded config and the 'Config YAML'
        box is cleared. On failure, only the 'Config YAML' box is updated (with an error
        message) and the previously stored default/current configs are left untouched.
    """
    print(f"Resetting UI configs for {self._name} to {file_path}.")
    try:
        default_config = load_config_from_yaml(file_path)
        if not isinstance(default_config, self._pipeline_config_cls):
            raise TypeError(
                f"Wrong config type. Expected '{self._pipeline_config_cls.__name__}', got "
                f"'{type(default_config).__name__}'."
            )
        current_config = default_config.model_copy(deep=True)
        update_dict = self.pipeline_config_group.update_ui_components_with_config_data(current_config)
    except Exception as e:
        print(f"Error resetting config: {e}")
        # Report the failure instead of silently redisplaying an older config, so the user
        # knows the reset did not take effect. No state has been modified at this point.
        return {self._config_yaml: f"Error loading config: {e}"}

    # Commit only once everything above succeeded, so the stored configs always match
    # what the UI components are being updated to.
    self._default_config = default_config
    self._current_config = current_config
    update_dict.update({self._config_yaml: None})
    return update_dict
def on_generate_config_button_click(self, data: dict):
    """Build and validate a pipeline config from the current UI values.

    The UI values in ``data`` are merged into a copy of the current config. The result is
    validated by round-tripping it through ``self._pipeline_config_cls``. Only after
    validation and rendering succeed is it stored as ``self._current_config``, which is the
    config that 'Start Training' passes to the training callback.

    Args:
        data: Mapping of UI input components to their current values. Gradio passes a dict
            because the click handler is registered with a set of input components.

    Returns:
        A Gradio update dict. On success, it refreshes the config-group components with the
        validated values and shows the config as YAML. On failure, it only updates the
        'Config YAML' box with an error message, and ``self._current_config`` is unchanged.
    """
    print(f"Generating config for {self._name}.")
    try:
        current = self._current_config
        # Hand the config group a deep copy: if the merge or the validation below fails, the
        # stored config must stay untouched, since 'Start Training' uses it as-is.
        base = current.model_copy(deep=True) if current is not None else None
        candidate = self.pipeline_config_group.update_config_with_ui_component_data(base, data)
        # Roundtrip to make sure that the config is valid.
        candidate = self._pipeline_config_cls.model_validate(candidate.model_dump())
        # Update the UI to reflect the validated config (in case some values were rounded or
        # otherwise modified in the process).
        update_dict = self.pipeline_config_group.update_ui_components_with_config_data(candidate)
        config_yaml = yaml.safe_dump(candidate.model_dump(), default_flow_style=False, sort_keys=False)
    except Exception as e:
        print(f"Error generating config: {e}")
        # Surface the error rather than redisplaying the previous config; the previous
        # config is still what 'Start Training' would use.
        return {self._config_yaml: f"Error generating config: {e}"}

    # Commit the new config only after it was validated and rendered successfully.
    self._current_config = candidate
    update_dict.update({self._config_yaml: config_yaml})
    return update_dict
def on_run_training_button_click(self):
    """Start training by handing the most recently stored config to the training callback."""
    config_to_train = self._current_config
    self._run_training_cb(config_to_train)
================================================
FILE: src/invoke_training/ui/index.html
================================================
invoke-training