Repository: invoke-ai/invoke-training Branch: main Commit: 363f83cdb5e6 Files: 246 Total size: 930.6 KB Directory structure: gitextract_5kosyax7/ ├── .github/ │ └── workflows/ │ ├── deploy.yaml │ └── test.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── docs/ │ ├── contributing/ │ │ ├── development_environment.md │ │ ├── directory_structure.md │ │ ├── documentation.md │ │ └── tests.md │ ├── get-started/ │ │ ├── installation.md │ │ └── quick-start.md │ ├── guides/ │ │ ├── dataset_formats.md │ │ ├── model_merge.md │ │ └── stable_diffusion/ │ │ ├── dpo_lora_sd.md │ │ ├── gnome_lora_masks_sdxl.md │ │ ├── robocats_finetune_sdxl.md │ │ └── textual_inversion_sdxl.md │ ├── index.md │ ├── reference/ │ │ └── config/ │ │ ├── index.md │ │ ├── pipelines/ │ │ │ ├── sd_lora.md │ │ │ ├── sd_textual_inversion.md │ │ │ ├── sdxl_finetune.md │ │ │ ├── sdxl_lora.md │ │ │ ├── sdxl_lora_and_textual_inversion.md │ │ │ └── sdxl_textual_inversion.md │ │ └── shared/ │ │ ├── data/ │ │ │ ├── data_loader_config.md │ │ │ └── dataset_config.md │ │ └── optimizer_config.md │ └── templates/ │ └── python/ │ └── material/ │ └── labels.html ├── mkdocs.yml ├── pyproject.toml ├── sample_data/ │ └── bruce_the_gnome/ │ └── data.jsonl ├── src/ │ └── invoke_training/ │ ├── __init__.py │ ├── _shared/ │ │ ├── __init__.py │ │ ├── accelerator/ │ │ │ ├── __init__.py │ │ │ └── accelerator_utils.py │ │ ├── checkpoints/ │ │ │ ├── __init__.py │ │ │ ├── checkpoint_tracker.py │ │ │ ├── lora_checkpoint_utils.py │ │ │ └── serialization.py │ │ ├── data/ │ │ │ ├── ARCHITECTURE.md │ │ │ ├── __init__.py │ │ │ ├── data_loaders/ │ │ │ │ ├── __init__.py │ │ │ │ ├── dreambooth_sd_dataloader.py │ │ │ │ ├── image_caption_flux_dataloader.py │ │ │ │ ├── image_caption_sd_dataloader.py │ │ │ │ ├── image_pair_preference_sd_dataloader.py │ │ │ │ └── textual_inversion_sd_dataloader.py │ │ │ ├── datasets/ │ │ │ │ ├── __init__.py │ │ │ │ ├── build_dataset.py │ │ │ │ ├── hf_image_caption_dataset.py │ │ │ │ ├── 
hf_image_pair_preference_dataset.py │ │ │ │ ├── image_caption_dir_dataset.py │ │ │ │ ├── image_caption_jsonl_dataset.py │ │ │ │ ├── image_dir_dataset.py │ │ │ │ ├── image_pair_preference_dataset.py │ │ │ │ └── transform_dataset.py │ │ │ ├── samplers/ │ │ │ │ ├── __init__.py │ │ │ │ ├── aspect_ratio_bucket_batch_sampler.py │ │ │ │ ├── batch_offset_sampler.py │ │ │ │ ├── concat_sampler.py │ │ │ │ ├── interleaved_sampler.py │ │ │ │ └── offset_sampler.py │ │ │ ├── transforms/ │ │ │ │ ├── __init__.py │ │ │ │ ├── caption_prefix_transform.py │ │ │ │ ├── concat_fields_transform.py │ │ │ │ ├── constant_field_transform.py │ │ │ │ ├── drop_field_transform.py │ │ │ │ ├── flux_image_transform.py │ │ │ │ ├── load_cache_transform.py │ │ │ │ ├── sd_image_transform.py │ │ │ │ ├── shuffle_caption_transform.py │ │ │ │ ├── template_caption_transform.py │ │ │ │ └── tensor_disk_cache.py │ │ │ └── utils/ │ │ │ ├── __init__.py │ │ │ ├── aspect_ratio_bucket_manager.py │ │ │ ├── resize.py │ │ │ └── resolution.py │ │ ├── flux/ │ │ │ ├── encoding_utils.py │ │ │ ├── lora_checkpoint_utils.py │ │ │ ├── model_loading_utils.py │ │ │ └── validation.py │ │ ├── optimizer/ │ │ │ ├── __init__.py │ │ │ └── optimizer_utils.py │ │ ├── stable_diffusion/ │ │ │ ├── __init__.py │ │ │ ├── base_model_version.py │ │ │ ├── checkpoint_utils.py │ │ │ ├── lora_checkpoint_utils.py │ │ │ ├── min_snr_weighting.py │ │ │ ├── model_loading_utils.py │ │ │ ├── textual_inversion.py │ │ │ ├── tokenize_captions.py │ │ │ └── validation.py │ │ ├── tools/ │ │ │ ├── __init__.py │ │ │ └── generate_images.py │ │ └── utils/ │ │ ├── import_xformers.py │ │ └── jsonl.py │ ├── config/ │ │ ├── __init__.py │ │ ├── base_pipeline_config.py │ │ ├── config_base_model.py │ │ ├── data/ │ │ │ ├── __init__.py │ │ │ ├── data_loader_config.py │ │ │ └── dataset_config.py │ │ ├── optimizer/ │ │ │ ├── __init__.py │ │ │ └── optimizer_config.py │ │ └── pipeline_config.py │ ├── model_merge/ │ │ ├── __init__.py │ │ ├── extract_lora.py │ │ ├── 
merge_models.py │ │ ├── merge_tasks_to_base.py │ │ ├── scripts/ │ │ │ ├── extract_lora_from_model_diff.py │ │ │ ├── merge_lora_into_model.py │ │ │ ├── merge_models.py │ │ │ └── merge_task_models_to_base_model.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── normalize_weights.py │ │ └── parse_model_arg.py │ ├── pipelines/ │ │ ├── __init__.py │ │ ├── _experimental/ │ │ │ └── sd_dpo_lora/ │ │ │ ├── config.py │ │ │ └── train.py │ │ ├── callbacks.py │ │ ├── flux/ │ │ │ └── lora/ │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── train.py │ │ ├── invoke_train.py │ │ ├── stable_diffusion/ │ │ │ ├── __init__.py │ │ │ ├── lora/ │ │ │ │ ├── __init__.py │ │ │ │ ├── config.py │ │ │ │ └── train.py │ │ │ └── textual_inversion/ │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── train.py │ │ └── stable_diffusion_xl/ │ │ ├── __init__.py │ │ ├── finetune/ │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── train.py │ │ ├── lora/ │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── train.py │ │ ├── lora_and_textual_inversion/ │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ └── train.py │ │ └── textual_inversion/ │ │ ├── __init__.py │ │ ├── config.py │ │ └── train.py │ ├── sample_configs/ │ │ ├── _experimental/ │ │ │ ├── sd_dpo_lora_pickapic_1x24gb.yaml │ │ │ └── sd_dpo_lora_refinement_pokemon_1x24gb.yaml │ │ ├── flux_lora_1x40gb.yaml │ │ ├── sd_lora_baroque_1x8gb.yaml │ │ ├── sd_textual_inversion_gnome_1x8gb.yaml │ │ ├── sdxl_finetune_baroque_1x24gb.yaml │ │ ├── sdxl_finetune_robocats_1x24gb.yaml │ │ ├── sdxl_lora_and_ti_gnome_1x24gb.yaml │ │ ├── sdxl_lora_baroque_1x24gb.yaml │ │ ├── sdxl_lora_baroque_1x8gb.yaml │ │ ├── sdxl_lora_masks_gnome_1x24gb.yaml │ │ ├── sdxl_textual_inversion_gnome_1x24gb.yaml │ │ └── sdxl_textual_inversion_masks_gnome_1x24gb.yaml │ ├── scripts/ │ │ ├── __init__.py │ │ ├── _experimental/ │ │ │ ├── auto_caption/ │ │ │ │ └── auto_caption_images.py │ │ │ ├── masks/ │ │ │ │ ├── clipseg.py │ │ │ │ ├── generate_masks.py │ │ │ │ └── generate_masks_for_jsonl_dataset.py 
│ │ │ └── rank_images.py │ │ ├── convert_sd_lora_to_kohya_format.py │ │ ├── invoke_generate_images.py │ │ ├── invoke_train.py │ │ ├── invoke_train_ui.py │ │ ├── invoke_visualize_data_loading.py │ │ └── utils/ │ │ └── image_dir_dataset.py │ └── ui/ │ ├── __init__.py │ ├── app.py │ ├── config_groups/ │ │ ├── __init__.py │ │ ├── aspect_ratio_bucket_config_group.py │ │ ├── base_pipeline_config_group.py │ │ ├── dataset_config_group.py │ │ ├── flux_lora_config_group.py │ │ ├── image_caption_sd_data_loader_config_group.py │ │ ├── optimizer_config_group.py │ │ ├── sd_lora_config_group.py │ │ ├── sd_textual_inversion_config_group.py │ │ ├── sdxl_finetune_config_group.py │ │ ├── sdxl_lora_and_textual_inversion_config_group.py │ │ ├── sdxl_lora_config_group.py │ │ ├── sdxl_textual_inversion_config_group.py │ │ ├── textual_inversion_sd_data_loader_config_group.py │ │ └── ui_config_element.py │ ├── gradio_blocks/ │ │ ├── header.py │ │ └── pipeline_tab.py │ ├── index.html │ ├── pages/ │ │ ├── data_page.py │ │ └── training_page.py │ └── utils/ │ ├── prompts.py │ └── utils.py └── tests/ └── invoke_training/ ├── _shared/ │ ├── __init__.py │ ├── checkpoints/ │ │ ├── test_checkpoint_tracker.py │ │ └── test_serialization.py │ ├── data/ │ │ ├── __init__.py │ │ ├── data_loaders/ │ │ │ ├── __init__.py │ │ │ ├── test_dreambooth_sd_dataloader.py │ │ │ ├── test_image_caption_sd_dataloader.py │ │ │ ├── test_image_pair_preference_sd_dataloader.py │ │ │ └── test_textual_inversion_sd_dataloader.py │ │ ├── dataset_fixtures.py │ │ ├── datasets/ │ │ │ ├── __init__.py │ │ │ ├── test_hf_image_caption_dataset.py │ │ │ ├── test_hf_image_pair_preference_dataset.py │ │ │ ├── test_image_caption_dir_dataset.py │ │ │ ├── test_image_caption_jsonl_dataset.py │ │ │ ├── test_image_dir_dataset.py │ │ │ ├── test_image_pair_preference_dataset.py │ │ │ └── test_transform_dataset.py │ │ ├── samplers/ │ │ │ ├── __init__.py │ │ │ ├── test_aspect_ratio_bucket_batch_sampler.py │ │ │ ├── test_batch_offset_sampler.py │ │ 
│ ├── test_concat_sampler.py │ │ │ ├── test_interleaved_sampler.py │ │ │ └── test_offset_sampler.py │ │ ├── transforms/ │ │ │ ├── __init__.py │ │ │ ├── test_caption_prefix_transform.py │ │ │ ├── test_concat_fields_transform.py │ │ │ ├── test_constant_field_transform.py │ │ │ ├── test_drop_field_transform.py │ │ │ ├── test_load_cache_transform.py │ │ │ ├── test_sd_image_transform.py │ │ │ ├── test_shuffle_caption_transform.py │ │ │ ├── test_template_caption_transform.py │ │ │ └── test_tensor_disk_cache.py │ │ └── utils/ │ │ ├── __init__.py │ │ ├── test_aspect_ratio_bucket_manager.py │ │ ├── test_resize.py │ │ └── test_resolution.py │ ├── stable_diffusion/ │ │ ├── __init__.py │ │ ├── test_base_model_version.py │ │ ├── test_lora_checkpoint_utils.py │ │ ├── test_model_loading_utils.py │ │ ├── test_textual_inversion.py │ │ └── ti_embedding_checkpoint_fixture.py │ └── utils/ │ └── test_jsonl.py ├── config/ │ └── pipelines/ │ └── test_pipeline_config.py ├── model_merge/ │ ├── __init__.py │ ├── test_merge_models.py │ ├── test_merge_tasks_to_base.py │ └── utils.py └── ui/ └── utils/ └── test_prompts.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/deploy.yaml ================================================ name: Deploy invoke-training docs on: push: branches: - main permissions: contents: write jobs: deploy: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Configure Git Credentials run: | git config user.name github-actions[bot] git config user.email 41898282+github-actions[bot]@users.noreply.github.com - uses: actions/setup-python@v4 with: python-version: "3.10" cache: pip cache-dependency-path: pyproject.toml - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install .[test] - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV - uses: actions/cache@v3 with: key: 
mkdocs-material-${{ env.cache_id }} path: .cache restore-keys: | mkdocs-material- - run: mkdocs gh-deploy --force ================================================ FILE: .github/workflows/test.yaml ================================================ name: Test invoke-training on: push: branches: - main pull_request: workflow_dispatch: jobs: build: runs-on: ubuntu-latest strategy: fail-fast: false matrix: python-version: ["3.12"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: pip cache-dependency-path: pyproject.toml - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install .[test] - name: Ruff lint run: | ruff check --output-format=github . - name: Ruff format run: | ruff format --check . - name: Test with pytest run: | pytest tests --junitxml=junit/test-results-${{ matrix.python-version }}.xml -m "not cuda and not loads_model" - name: Upload pytest test results uses: actions/upload-artifact@v4 with: name: pytest-results-${{ matrix.python-version }} path: junit/test-results-${{ matrix.python-version }}.xml # Use always() to always run this step to publish test results when there are test failures. if: ${{ always() }} ================================================ FILE: .gitignore ================================================ /output/ /test_configs/ /data/ # pyenv .python-version # VSCode .vscode/ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ junit/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # poetry # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control #poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. #pdm.lock # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # in version control. # https://pdm.fming.dev/#use-with-ide .pdm.toml # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ .aider* ================================================ FILE: .pre-commit-config.yaml ================================================ # See https://pre-commit.com/ for usage and config. repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. rev: v0.1.7 hooks: # Run the linter. - id: ruff # Run the formatter. - id: ruff-format ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # invoke-training A library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion, etc.) that can be used in [InvokeAI](https://github.com/invoke-ai/InvokeAI). > [!WARNING] > `invoke-training` is still under active development, and breaking changes are likely. Full backwards compatibility will not be guaranteed until v1.0.0. > In the meantime, I recommend pinning to a specific commit hash. ## Documentation The full documentation is available at [invoke-ai.github.io/invoke-training](https://invoke-ai.github.io/invoke-training/). ## Training Modes - Stable Diffusion - LoRA - DreamBooth LoRA - Textual Inversion - Stable Diffusion XL - Full finetuning - LoRA - DreamBooth LoRA - Textual Inversion - LoRA and Textual Inversion More training modes coming soon! ## Installation See the [Installation](https://invoke-ai.github.io/invoke-training/get-started/installation/) section of the documentation. ## Quick Start `invoke-training` pipelines can be configured and launched from either the CLI or the GUI. ### CLI Run training via the CLI with type-checked YAML configuration files for maximum control: ```bash invoke-train --cfg-file src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml ``` ### GUI Run training via the GUI for a simpler starting point.
```bash invoke-train-ui # Or, you can optionally override the default host and port: invoke-train-ui --host 0.0.0.0 --port 1234 ``` ## Features Training progress can be monitored with [Tensorboard](https://www.tensorflow.org/tensorboard): ![Screenshot of the Tensorboard UI showing validation images.](docs/images/tensorboard_val_images_screenshot.png) _Validation images in the Tensorboard UI._ All trained models are compatible with InvokeAI: ![Screenshot of the InvokeAI UI with an example of a Yoda pokemon generated using a Pokemon LoRA model.](docs/images/invokeai_yoda_pokemon_lora.png) _Example image generated with the prompt "A cute yoda pokemon creature." and a trained Pokemon LoRA._ ## Contributing Contributors are welcome. For developer guidance, see the [Contributing](https://invoke-ai.github.io/invoke-training/contributing/development_environment/) section of the documentation. ================================================ FILE: docs/contributing/development_environment.md ================================================ # Development Environment Setup See the [developer installation instructions](../get-started/installation.md#developer-installation). ================================================ FILE: docs/contributing/directory_structure.md ================================================ # Directory Structure ```bash invoke-training/ ├── README.md ├── docs/ ├── src/ │ └── invoke_training/ │ ├── _shared/ # Utilities shared across multiple pipelines. High unit test coverage. │ ├── config/ # Config structures shared by multiple pipelines. │ ├── pipelines/ # Each pipeline is isolated in its own directory with a train.py and config.py. │ │ ├── stable_diffusion/ │ │ │ ├── lora/ │ │ │ │ ├── config.py │ │ │ │ └── train.py │ │ │ └── textual_inversion/ │ │ │ └── ... │ │ ├── stable_diffusion_xl/ │ │ └── ... │ └── scripts/ # Main entrypoints. └── tests/ # Mirrors src/ directory.
``` ================================================ FILE: docs/contributing/documentation.md ================================================ # Documentation The documentation site is generated using [mkdocs](https://www.mkdocs.org/) and [mkdocstrings-python](https://mkdocstrings.github.io/python/). To view your documentation changes locally, run `mkdocs serve`. ================================================ FILE: docs/contributing/tests.md ================================================ # Tests Run all unit tests with: ```bash pytest tests/ ``` There are some test 'markers' defined in [pyproject.toml](https://github.com/invoke-ai/invoke-training/blob/main/pyproject.toml) that can be used to skip some tests. For example, the following command skips tests that require a GPU or require downloading model weights: ```bash pytest tests/ -m "not cuda and not loads_model" ``` ================================================ FILE: docs/get-started/installation.md ================================================ # Installation ## Requirements 1. Python 3.10, 3.11 and 3.12 are currently supported. Check your Python version by running `python -V`. 2. An NVIDIA GPU with >= 8 GB VRAM is recommended for model training. ## Basic Installation 0. Open your terminal and navigate to the directory where you want to clone the `invoke-training` repo. 1. Clone the repo: ```bash git clone https://github.com/invoke-ai/invoke-training.git ``` 2. Create and activate a python [virtual environment](https://docs.python.org/3/library/venv.html#creating-virtual-environments). This creates an isolated environment for `invoke-training` and its dependencies that won't interfere with other python environments on your system, including any installations of [InvokeAI](https://www.github.com/invoke-ai/invokeai). ```bash # Navigate to the invoke-training directory. cd invoke-training # Create a new virtual environment named `invoketraining`. 
python -m venv invoketraining # Activate the new virtual environment. # On Windows: .\invoketraining\Scripts\activate # On macOS / Linux: source invoketraining/bin/activate ``` 3. Install `invoke-training` and its dependencies. Run the appropriate install command for your system. ```bash # A recent version of pip is required, so first upgrade pip: python -m pip install --upgrade pip # Install - Windows or Linux with an NVIDIA GPU: pip install ".[test]" --extra-index-url https://download.pytorch.org/whl/cu126 # Install - Linux with no GPU: pip install ".[test]" --extra-index-url https://download.pytorch.org/whl/cpu # Install - All other systems: pip install ".[test]" ``` In the future, before you run `invoke-training`, you must activate the virtual environment you created during installation, using the same command you used during installation. ## Developer Installation Consider forking the repo if you plan to contribute code changes. Follow the above installation instructions, cloning your fork instead of this repo if you made a fork. Next, we suggest setting up the repo's pre-commit hooks to automatically format and lint your contributions: 1. (_Optional_) Install the pre-commit hooks: `pre-commit install`. This will run static analysis tools (ruff) on `git commit`. 2. (_Optional_) Set up `ruff` in your IDE of choice. ================================================ FILE: docs/get-started/quick-start.md ================================================ # Quick Start `invoke-training` has both a GUI and a CLI (for advanced users). The instructions for getting started with both options can be found on this page. There is also a video introduction to `invoke-training`: ## Quick Start - GUI ### 1. Installation Follow the [`invoke-training` installation instructions](./installation.md). ### 2. Launch the GUI Activate the virtual environment you created during installation, using the same command you used during installation.
You'll need to do this every time you run `invoke-training`. ```bash # From the invoke-training directory: invoke-train-ui # Or, you can optionally override the default host and port: invoke-train-ui --host 0.0.0.0 --port 1234 ``` Access the GUI in your browser at the URL printed to the console. ### 3. Configure the training job Select the desired training pipeline type in the top-level tab. For this tutorial, we don't need to change any of the configuration values. The preset configuration should work well. ### 4. Generate the YAML configuration Click on 'Generate Config' to generate a YAML configuration file. This YAML configuration file could be used to launch the training job from the CLI, if desired. ### 5. Start training Click on 'Start Training' and check your terminal for progress logs. ### 6. Monitor training Monitor the training process with Tensorboard by running `tensorboard --logdir output/` and visiting [localhost:6006](http://localhost:6006) in your browser. Here you can see generated validation images throughout the training process. ![Screenshot of the Tensorboard UI showing validation images.](../images/tensorboard_val_images_screenshot.png) _Validation images in the Tensorboard UI._ ### 7. InvokeAI Select a checkpoint based on the quality of the generated images. If you haven't already, set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation. Copy your selected LoRA checkpoint into your `${INVOKEAI_ROOT}/autoimport/lora` directory. For example: ```bash # Note: You will have to replace the timestamp in the checkpoint path. cp output/1691088769.5694647/checkpoint_epoch-00000002.safetensors ${INVOKEAI_ROOT}/autoimport/lora/pokemon_epoch-00000002.safetensors ``` You can now use your trained Pokemon LoRA in the InvokeAI UI!
🎉 ![Screenshot of the InvokeAI UI with an example of a Yoda pokemon generated using a Pokemon LoRA model.](../images/invokeai_yoda_pokemon_lora.png) _Example image generated with the prompt "A cute yoda pokemon creature." and Pokemon LoRA._ ## Quick Start - CLI ### 1. Installation Follow the [`invoke-training` installation instructions](./installation.md). ### 2. Training Activate the virtual environment you created during installation, using the same command you used during installation. You'll need to do this every time you run `invoke-training`. See the [Textual Inversion - SDXL](../guides/stable_diffusion/textual_inversion_sdxl.md) tutorial for instructions on how to train a model via the CLI. ================================================ FILE: docs/guides/dataset_formats.md ================================================ # Dataset Formats `invoke-training` supports the following dataset formats: - `IMAGE_CAPTION_JSONL_DATASET`: A local image-caption dataset described by a single `.jsonl` file. - `IMAGE_CAPTION_DIR_DATASET`: A local directory of images with associated `.txt` caption files. - `IMAGE_DIR_DATASET`: A local directory of images (without captions). - `HF_HUB_IMAGE_CAPTION_DATASET`: A Hugging Face Hub dataset containing images and captions. See the documentation for a particular training pipeline to see which dataset formats it supports. The following sections explain each of these formats in more detail. ## `IMAGE_CAPTION_JSONL_DATASET` Config documentation: [ImageCaptionJsonlDatasetConfig][invoke_training.config.data.dataset_config.ImageCaptionJsonlDatasetConfig] An `IMAGE_CAPTION_JSONL_DATASET` consists of a single `.jsonl` file containing image paths and associated captions. Sample directory structure: ```bash my_custom_dataset/ ├── data.jsonl └── train/ ├── 0001.png ├── 0002.png ├── 0003.png └── ...
``` The contents of `data.jsonl` would be: ```json {"file_name": "train/0001.png", "text": "This is a caption describing image 0001."} {"file_name": "train/0002.png", "text": "This is a caption describing image 0002."} {"file_name": "train/0003.png", "text": "This is a caption describing image 0003."} ``` The image file paths can be either absolute paths, or relative to the `.jsonl` file. Finally, this dataset can be used with the following pipeline dataset configuration: ```yaml type: IMAGE_CAPTION_JSONL_DATASET jsonl_path: /path/to/my_custom_dataset/data.jsonl image_column: file_name caption_column: text ``` A useful characteristic of this dataset format is that a `.jsonl` file can reference an image file anywhere on the local disk. It is common to maintain multiple `.jsonl` datasets that reference some of the same images without needing multiple copies of those images on disk. ## `IMAGE_CAPTION_DIR_DATASET` Config documentation: [ImageCaptionDirDatasetConfig][invoke_training.config.data.dataset_config.ImageCaptionDirDatasetConfig] An `IMAGE_CAPTION_DIR_DATASET` consists of a directory of image files and corresponding `.txt` caption files of the same name. Sample directory structure: ```bash my_custom_dataset/ ├── 0001.png ├── 0001.txt ├── 0002.jpg ├── 0002.txt ├── 0003.png ├── 0003.txt └── ... ``` Each `.txt` file should contain a caption on the first line of the file. Here are the sample contents of `0001.txt`: ```txt title="0001.txt" this is a caption for example 0001 ``` This dataset can be used with the following pipeline dataset configuration: ```yaml type: IMAGE_CAPTION_DIR_DATASET dataset_dir: /path/to/my_custom_dataset ``` ## `IMAGE_DIR_DATASET` Config documentation: [ImageDirDatasetConfig][invoke_training.config.data.dataset_config.ImageDirDatasetConfig] An `IMAGE_DIR_DATASET` consists of a single directory of images (without captions). Sample directory structure: ```bash my_custom_dataset/ ├── 0001.png ├── 0002.jpg ├── 0003.png └── ...
``` This dataset can be used with the following pipeline dataset configuration: ```yaml type: IMAGE_DIR_DATASET dataset_dir: /path/to/my_custom_dataset ``` ## `HF_HUB_IMAGE_CAPTION_DATASET` Config documentation: [HFHubImageCaptionDatasetConfig][invoke_training.config.data.dataset_config.HFHubImageCaptionDatasetConfig] The `HF_HUB_IMAGE_CAPTION_DATASET` dataset format can be used to access publicly available datasets on the [Hugging Face Hub](https://huggingface.co/datasets). You can filter for the `Text-to-Image` task to find relevant datasets that contain both an image column and a caption column. [lambdalabs/pokemon-blip-captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) is a popular choice if you're not sure where to start. ================================================ FILE: docs/guides/model_merge.md ================================================ # Model Merging `invoke-training` provides utility scripts for several common model merging workflows. This page contains a summary of the available tools. ## `extract_lora_from_model_diff.py` Extract a LoRA model that represents the difference between two base models. Note that the extracted LoRA model is a lossy representation of the difference between the models, so some degradation in quality is expected. For usage docs, run: ```bash python src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py -h ``` ## `merge_lora_into_model.py` Merge a LoRA model into a base model to produce a new base model. For usage docs, run: ```bash python src/invoke_training/model_merge/scripts/merge_lora_into_model.py -h ``` ## `merge_models.py` Merge 2 or more base models to produce a single base model (using either LERP or SLERP). This is a simple merge strategy that merges all model weights in the same way.
For usage docs, run: ```bash python src/invoke_training/model_merge/scripts/merge_models.py -h ``` ## `merge_task_models_to_base_model.py` Merge 1 or more task-specific base models into a single starting base model (using either [TIES](https://arxiv.org/abs/2306.01708) or [DARE](https://arxiv.org/abs/2311.03099)). This merge strategy aims to preserve the task-specific behaviors of the task models while making only small changes to the original base model. This approach enables multiple task models to be merged without excessive interference between them. If you want to merge a task-specific LoRA into a base model using this strategy, first use `merge_lora_into_model.py` to produce a task-specific base model, then merge that new base model using this strategy. For usage docs, run: ```bash python src/invoke_training/model_merge/scripts/merge_task_models_to_base_model.py -h ``` ================================================ FILE: docs/guides/stable_diffusion/dpo_lora_sd.md ================================================ # (Experimental) Diffusion DPO - SD !!! tip "Experimental" The Diffusion Direct Preference Optimization training pipeline is still experimental. Support may be dropped at any time. This tutorial walks through some initial experiments around using Diffusion Direct Preference Optimization (DPO) ([paper](https://arxiv.org/abs/2311.12908)) to train Stable Diffusion LoRA models. ## Experiment 1: `pickapic_v2` LoRA Training The Diffusion-DPO paper does full model fine-tuning on the [pickapic_v2](https://huggingface.co/datasets/yuvalkirstain/pickapic_v2) dataset, which consists of roughly 1M AI-generated image pairs with preference annotations. In this experiment, we attempt to fine-tune a Stable Diffusion LoRA model using a small subset of the pickapic_v2 dataset. 
Run this experiment with the following command: ```bash invoke-train -c src/invoke_training/sample_configs/_experimental/sd_dpo_lora_pickapic_1x24gb.yaml ``` Here is a cherry-picked example of a prompt for which this training process was clearly beneficial. Prompt: "*A galaxy-colored figurine is floating over the sea at sunset, photorealistic*" | Before DPO Training | After DPO Training (same seed)| | - | - | | ![Sample image before DPO training.](../../images/dpo/before_dpo.jpg) | ![Sample image after DPO training.](../../images/dpo/after_dpo.jpg) | ## Experiment 2: LoRA Model Refinement As a second experiment, we attempt the following workflow: 1. Train a Stable Diffusion LoRA model on a particular style. 2. Generate pairs of images in that style with the trained LoRA model. 3. Annotate the preferred image from each pair. 4. Apply Diffusion-DPO to the preference-annotated pairs to further fine-tune the LoRA model. Note: The steps listed below are pretty rough. They are included primarily for reference for someone looking to resume this line of work in the future. ### 1. Train a style LoRA ```bash invoke-train -c src/invoke_training/sample_configs/sd_lora_pokemon_1x8gb.yaml ``` ### 2. Generate images Prepare ~100 relevant prompts that will be used to generate training data with the freshly-trained LoRA model. Add the prompts to a `.txt` file - one prompt per line. Example prompts: ```txt a cute orange pokemon character with pointy ears a drawing of a purple fish a cartoon blob with a smile on its face a drawing of a snail with big eyes ... ``` ```bash # Convert the LoRA checkpoint of interest to Kohya format. # You will have to change the path timestamps in this example command. # TODO(ryand): This manual conversion shouldn't be necessary.
python src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py \ --src-ckpt-dir output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003/ \ --dst-ckpt-file output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003_kohya.safetensors # Generate 2 pairs of images for each prompt. invoke-generate-images \ -o output/pokemon_pairs \ -m runwayml/stable-diffusion-v1-5 \ -v fp16 \ -l output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003_kohya.safetensors \ --sd-version SD \ --prompt-file path/to/prompts.txt \ --set-size 2 \ --num-sets 2 \ --height 512 \ --width 512 ``` ### 3. Annotate the image pair preferences Launch the gradio UI for selecting image pair preferences. ```bash # Note: rank_images.py accepts a full training pipeline config, but only uses the dataset configuration. python src/invoke_training/scripts/_experimental/rank_images.py -c src/invoke_training/sample_configs/_experimental/sd_dpo_lora_refinement_pokemon_1x24gb.yaml ``` After completing the pair annotations, click "Save Metadata" and move the resultant metadata file to your image data directory (e.g. `output/pokemon_pairs/metadata.jsonl`). ### 4. Run Diffusion-DPO ```bash invoke-train -c src/invoke_training/sample_configs/_experimental/sd_dpo_lora_refinement_pokemon_1x24gb.yaml ``` ================================================ FILE: docs/guides/stable_diffusion/gnome_lora_masks_sdxl.md ================================================ # LoRA with Masks - SDXL This tutorial explains how to prepare masks for an image dataset and then use that dataset to train an SDXL LoRA model. Masks can be used to weight regions of images in a dataset to control how much they contribute to the training process. In this tutorial we will use masks to train on a small dataset of images of Bruce the Gnome (4 images). With such a small dataset, there is a high risk of overfitting to the background elements from the images. 
We will use masks to avoid this problem and focus only on the object of interest. ## 1 - Dataset Preparation For this tutorial, we'll use a dataset consisting of 4 images of Bruce the Gnome: | | | | - | - | | ![bruce_the_gnome dataset image 1.](../../images/bruce_the_gnome/001.jpg) | ![bruce_the_gnome dataset image 2.](../../images/bruce_the_gnome/002.jpg) | | ![bruce_the_gnome dataset image 3.](../../images/bruce_the_gnome/003.jpg) | ![bruce_the_gnome dataset image 4.](../../images/bruce_the_gnome/004.jpg) | This sample dataset is included in the invoke-training repo under [sample_data/bruce_the_gnome](https://github.com/invoke-ai/invoke-training/tree/main/sample_data/bruce_the_gnome). ## 2 - Generate Masks Use the `generate_masks_for_jsonl_dataset.py` script to generate masks for your dataset based on a single prompt. In this case we are using the prompt `"a stuffed gnome"`: ```bash python src/invoke_training/scripts/_experimental/masks/generate_masks_for_jsonl_dataset.py \ --in-jsonl sample_data/bruce_the_gnome/data.jsonl \ --out-jsonl sample_data/bruce_the_gnome/data_masks.jsonl \ --prompt "a stuffed gnome" ``` The mask generation script will produce the following outputs: - A directory of generated masks: `sample_data/bruce_the_gnome/masks/` - A new `.jsonl` file that references the mask images: `sample_data/bruce_the_gnome/data_masks.jsonl` ## 3 - Review the Generated Masks Review the generated masks to make sure that the target regions were masked. You may need to adjust the prompt and re-generate the masks to achieve the desired result. Alternatively, you can edit the masks manually. The masks are simply single-channel grayscale images (0=background, 255=foreground).
Here are some examples of the masks that we just generated: | | | | - | - | | ![bruce_the_gnome dataset image 1.](../../images/bruce_the_gnome/001.jpg) | ![bruce_the_gnome dataset image 1 mask.](../../images/bruce_masks/001_mask.png) | | ![bruce_the_gnome dataset image 2.](../../images/bruce_the_gnome/002.jpg) | ![bruce_the_gnome dataset image 2 mask.](../../images/bruce_masks/002_mask.png) | ## 4 - Configuration Below is the training configuration that we'll use for this tutorial. Raw config file: [src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml). ```yaml title="sdxl_lora_masks_gnome_1x24gb.yaml" --8<-- "src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml" ``` Full documentation of all of the configuration options is here: [LoRA SDXL Config](../../reference/config/pipelines/sdxl_lora.md) There are a few things to note about this training config: - We set `use_masks: True` in order to use the masks that we generated. This configuration is only compatible with datasets that have mask data. - The `learning_rate`, `max_train_steps`, `save_every_n_steps`, and `validate_every_n_steps` are all _lower_ than typical for an SDXL LoRA training pipeline. The combination of masking with the small dataset size causes training to progress very quickly. These configuration fields were all adjusted accordingly to avoid overfitting. ## 5 - Start Training Launch the training run. ```bash # From inside the invoke-training/ source directory: invoke-train -c src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml ``` Training takes ~30 mins on an NVIDIA RTX 4090. ## 6 - Monitor In a new terminal, launch Tensorboard to monitor the training run: ```bash tensorboard --logdir output/ ``` Access Tensorboard at [localhost:6006](http://localhost:6006) in your browser.
Sample images will be logged to Tensorboard so that you can see how the model is evolving. Once training is complete, select the model checkpoint that produces the best visual results. For this tutorial, we'll use the checkpoint from step 300: ![Screenshot of the Tensorboard UI showing the validation images for step 300.](../../images/bruce_masks/bruce_masks_step_300.jpg) *Screenshot of the Tensorboard UI showing the validation images for step 300. The validation prompt was: "A stuffed gnome at the beach with a pina colada in its hand.".* ## 7 - Import into InvokeAI If you haven't already, set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation. Import your trained LoRA model from the 'Models' tab. Congratulations, you can now use your new Bruce-the-Gnome model! 🎉 ================================================ FILE: docs/guides/stable_diffusion/robocats_finetune_sdxl.md ================================================ # Finetune - SDXL This tutorial explains how to do a full finetune training run on a [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) base model. ## 0 - Prerequisites Full model finetuning is more compute-intensive than parameter-efficient finetuning alternatives (e.g. LoRA or Textual Inversion). This tutorial requires a minimum of 24GB of GPU VRAM. ## 1 - Dataset Preparation For this tutorial, we will use a dataset consisting of 14 images of robocats. The images were auto-captioned.
Here are some sample images from the dataset, including their captions: | | | | - | - | | ![A white robot with blue eyes and a yellow nose sits on a rock, gazing at the camera, with a pink tree and a white cat in the background.](../../images/robocats/sipu3h70yb87rju8a8l36ejr.jpg) | ![A white cat with green eyes and a blue collar sits on a moss-covered rock in a forest, gazing directly at the camera.](../../images/robocats/v2h3ld50bi9owhhzo9gf9utg.jpg) | | *A white robot with blue eyes and a yellow nose sits on a rock, gazing at the camera, with a pink tree and a white cat in the background.* | *A white cat with green eyes and a blue collar sits on a moss-covered rock in a forest, gazing directly at the camera.* | ## 2 - Configuration Below is the training configuration that we'll use for this tutorial. Raw config file: [src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml). ```yaml title="sdxl_finetune_robocats_1x24gb.yaml" --8<-- "src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml" ``` Full documentation of all of the configuration options is here: [Finetune SDXL Config](../../reference/config/pipelines/sdxl_finetune.md) !!! note "`save_checkpoint_format`" Note the `save_checkpoint_format` setting, as it is unique to full finetune training. For this tutorial, we have set `save_checkpoint_format: trained_only_diffusers`. This means that only the UNet model will be saved at each checkpoint, and it will be saved in diffusers format. This setting conserves disk space by not redundantly saving the non-trained weights. Before these UNet checkpoints can be used, they must either be merged into a full model, or extracted into a LoRA. Instructions for this follow later in this tutorial. 
A full explanation of the `save_checkpoint_format` options can be found here: [save_checkpoint_format][invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig.save_checkpoint_format]. ## 3 - Start Training Launch the training run. ```bash # From inside the invoke-training/ source directory: invoke-train -c src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml ``` Training takes ~45 mins on an NVIDIA RTX 4090. ## 4 - Monitor In a new terminal, launch Tensorboard to monitor the training run: ```bash tensorboard --logdir output/ ``` Access Tensorboard at [localhost:6006](http://localhost:6006) in your browser. Sample images will be logged to Tensorboard so that you can see how the model is evolving. Once training is complete, select the model checkpoint that produces the best visual results. ## 5 - Prepare the trained model Since we set `save_checkpoint_format: trained_only_diffusers`, our selected checkpoint only contains the UNet model weights. The checkpoint has the following directory structure: ```bash output/robocats/sdxl_finetune/1715373799.3558652/checkpoints/checkpoint-epoch_00000500-step_00002000/ └── unet ├── config.json └── diffusion_pytorch_model.safetensors ``` Before we can use this trained model, we must do one of the following: - Prepare a full diffusers checkpoint with the new UNet weights. - Extract the difference between the trained UNet and the original UNet into a LoRA model. ### Prepare a full model If we want to use our finetuned UNet model, we must first package it into a format supported by applications like InvokeAI. In this section we will assume that we have a full SDXL base model in diffusers format. It should have a directory structure like the one shown below.
We simply need to replace the `unet/` directory with the one from our selected training checkpoint: ```bash stable-diffusion-xl-base-1.0 ├── model_index.json ├── scheduler │ └── scheduler_config.json ├── text_encoder │ ├── config.json │ └── model.fp16.safetensors ├── text_encoder_2 │ ├── config.json │ └── model.fp16.safetensors ├── tokenizer │ ├── merges.txt │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── vocab.json ├── tokenizer_2 │ ├── merges.txt │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── vocab.json ├── unet # <-- Replace this directory with the trained checkpoint. │ ├── config.json │ └── diffusion_pytorch_model.fp16.safetensors ├── vae │ ├── config.json │ └── diffusion_pytorch_model.fp16.safetensors └── vae_1_0 └── diffusion_pytorch_model.fp16.safetensors ``` !!! note "diffusers variants (e.g. 'fp16')" In this example, notice that the `*.safetensors` files contain `.fp16.` in their filenames. Hugging Face refers to this identifier as a "variant". It is used to select between multiple model variants in their model hub. In this case, we should add the `.fp16.` variant tag to our finetuned UNet for consistency with the rest of the model. Since we set `save_dtype: float16` in our training config, the `fp16` tag accurately represents the precision of our UNet model file. ### Extract a LoRA model An alternative to using the finetuned UNet model directly is to compare it against the original and extract the difference as a LoRA model. The resultant LoRA has a much smaller file size and can be applied to any base model. But, the LoRA model is a *lossy* representation of the difference, so some quality degradation is expected. 
To extract a LoRA model, run the following command: ```bash python src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py \ --model-type SDXL \ --model-orig path/to/stable-diffusion-xl-base-1.0 \ --model-tuned output/robocats/sdxl_finetune/1715373799.3558652/checkpoints/checkpoint-epoch_00000500-step_00002000 \ --save-to robocats_lora_step_2000.safetensors \ --lora-rank 32 ``` ## 6 - Import into InvokeAI If you haven't already, setup [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation. Import your finetuned diffusers model or your extracted LoRA from the 'Models' tab. Congratulations, you can now use your new robocat model! 🎉 ## 7 - Comparison: Finetune vs. LoRA Extraction As noted earlier, the LoRA extraction process is lossy for a number of reasons. Below, we compare images generated with the same seed and prompt for 3 different model configurations. Prompt: *In robocat style, a robotic lion in the jungle.* | SDXL Base 1.0 | w/ Finetuned UNet | w/ Extracted LoRA | | - | - | - | | ![Image generated with SDXL Base 1.0. Prompt: In robocat style, a robotic lion in the jungle.](../../images/robocats/lion_base.jpg) | ![Image generated with finetuned UNet. Prompt: In robocat style, a robotic lion in the jungle.](../../images/robocats/lion_finetuned.jpg) | ![Image generated with extracted LoRA. Prompt: In robocat style, a robotic lion in the jungle.](../../images/robocats/lion_extracted_lora.jpg) ================================================ FILE: docs/guides/stable_diffusion/textual_inversion_sdxl.md ================================================ # Textual Inversion - SDXL This tutorial walks through a [Textual Inversion](https://arxiv.org/abs/2208.01618) training run with a [Stable Diffusion XL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) base model. 
## 1 - Dataset For this tutorial, we'll use a dataset consisting of 4 images of Bruce the Gnome: | | | | - | - | | ![bruce_the_gnome dataset image 1.](../../images/bruce_the_gnome/001.jpg) | ![bruce_the_gnome dataset image 2.](../../images/bruce_the_gnome/002.jpg) | | ![bruce_the_gnome dataset image 3.](../../images/bruce_the_gnome/003.jpg) | ![bruce_the_gnome dataset image 4.](../../images/bruce_the_gnome/004.jpg) | This sample dataset is included in the invoke-training repo under [sample_data/bruce_the_gnome](https://github.com/invoke-ai/invoke-training/tree/main/sample_data/bruce_the_gnome). Here are a few tips for preparing a Textual Inversion dataset: - Aim for 4 to 50 images of your concept (object / style). The optimal number depends on many factors, and can be much higher than this for some use cases. - Vary all of the image features that you *don't* want your TI embedding to contain (e.g. background, pose, lighting, etc.). ## 2 - Configuration Below is the training configuration that we'll use for this tutorial. Raw config file: [src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml). Full config reference docs: [Textual Inversion SDXL Config](../../reference/config/pipelines/sdxl_textual_inversion.md) ```yaml title="sdxl_textual_inversion_gnome_1x24gb.yaml" --8<-- "src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml" ``` ## 3 - Start Training [Install invoke-training](../../get-started/installation.md), if you haven't already. Launch the Textual Inversion training pipeline: ```bash # From inside the invoke-training/ source directory: invoke-train -c src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml ``` Training takes ~40 mins on an NVIDIA RTX 4090. 
## 4 - Monitor In a new terminal, launch Tensorboard to monitor the training run: ```bash tensorboard --logdir output/ ``` Access Tensorboard at [localhost:6006](http://localhost:6006) in your browser. Sample images will be logged to Tensorboard so that you can see how the Textual Inversion embedding is evolving. Once training is complete, select the epoch that produces the best visual results. For this tutorial, we'll choose epoch 500: ![Screenshot of the Tensorboard UI showing the validation images for epoch 500.](../../images/tensorboard_bruce_the_gnome_epoch_500.png) *Screenshot of the Tensorboard UI showing the validation images for epoch 500.* ## 5 - Transfer to InvokeAI If you haven't already, set up [InvokeAI](https://github.com/invoke-ai/InvokeAI) by following its documentation. Copy the selected TI embedding into your `${INVOKEAI_ROOT}/autoimport/embedding/` directory. For example: ```bash cp output/sdxl_ti_bruce_the_gnome/1702587511.2273068/checkpoint_epoch-00000500.safetensors ${INVOKEAI_ROOT}/autoimport/embedding/bruce_the_gnome.safetensors ``` Note that we renamed the file to `bruce_the_gnome.safetensors`. You can choose any file name, but this will become the token used to reference your embedding. So, in our case, we can refer to our new embedding by including `<bruce_the_gnome>` in our prompts. Launch InvokeAI and you can now use your new `bruce_the_gnome` TI embedding! 🎉 ![Screenshot of the InvokeAI UI with an example of an image generated with the bruce_the_gnome TI embedding.](../../images/invokeai_bruce_the_gnome_ti.png) *Example image generated with the prompt "`a photo of <bruce_the_gnome> at the park`".* ================================================ FILE: docs/index.md ================================================ # invoke-training A library for training custom Stable Diffusion models (fine-tuning, LoRA training, textual inversion, etc.) that can be used in [InvokeAI](https://github.com/invoke-ai/InvokeAI).
## Documentation The documentation is organized as follows: - [Get Started](get-started/installation.md): Install `invoke-training` and run your first training pipeline. - [Guides](guides/dataset_formats.md): Full tutorials for running popular training pipelines. - [Config Reference](reference/config/index.md): Reference documentation for all supported training configuration options. - [Contributing](contributing/development_environment.md): Information for `invoke-training` developers. ================================================ FILE: docs/reference/config/index.md ================================================ # Config Reference This section contains reference documentation for the `invoke-training` configuration schema (i.e. documentation for all of the supported training options). This documentation uses Python typing semantics to define the configuration schema. Typically the configuration for a training run is specified in a YAML file and then parsed against this schema. ================================================ FILE: docs/reference/config/pipelines/sd_lora.md ================================================ # `SdLoraConfig` ::: invoke_training.pipelines.stable_diffusion.lora.config.SdLoraConfig options: members: - type ::: invoke_training.pipelines.stable_diffusion.lora.config.SdLoraConfig options: filters: - "!^model_config" - "!^type" ================================================ FILE: docs/reference/config/pipelines/sd_textual_inversion.md ================================================ # `SdTextualInversionConfig` ::: invoke_training.pipelines.stable_diffusion.textual_inversion.config.SdTextualInversionConfig options: members: - type ::: invoke_training.pipelines.stable_diffusion.textual_inversion.config.SdTextualInversionConfig options: filters: - "!^model_config" - "!^type" ================================================ FILE: docs/reference/config/pipelines/sdxl_finetune.md ================================================ #
`SdxlFinetuneConfig` ::: invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig options: members: - type ::: invoke_training.pipelines.stable_diffusion_xl.finetune.config.SdxlFinetuneConfig options: filters: - "!^model_config" - "!^type" ================================================ FILE: docs/reference/config/pipelines/sdxl_lora.md ================================================ # `SdxlLoraConfig` ::: invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig options: members: - type ::: invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig options: filters: - "!^model_config" - "!^type" ================================================ FILE: docs/reference/config/pipelines/sdxl_lora_and_textual_inversion.md ================================================ # `SdxlLoraAndTextualInversionConfig` ::: invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config.SdxlLoraAndTextualInversionConfig options: members: - type ::: invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config.SdxlLoraAndTextualInversionConfig options: filters: - "!^model_config" - "!^type" ================================================ FILE: docs/reference/config/pipelines/sdxl_textual_inversion.md ================================================ # `SdxlTextualInversionConfig` Below is a sample yaml config file for Textual Inversion SDXL training ([raw file](https://github.com/invoke-ai/invoke-training/blob/main/src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml)). All of the configuration fields are explained in detail on this page. 
```yaml title="sdxl_textual_inversion_gnome_1x24gb.yaml" --8<-- "src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml" ``` ::: invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config.SdxlTextualInversionConfig options: members: - type ::: invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config.SdxlTextualInversionConfig options: filters: - "!^model_config" - "!^type" ================================================ FILE: docs/reference/config/shared/data/data_loader_config.md ================================================ ::: invoke_training.config.data.data_loader_config options: filters: - "!^model_config" ================================================ FILE: docs/reference/config/shared/data/dataset_config.md ================================================ ::: invoke_training.config.data.dataset_config options: filters: - "!^model_config" ================================================ FILE: docs/reference/config/shared/optimizer_config.md ================================================ ::: invoke_training.config.optimizer.optimizer_config options: filters: - "!^model_config" ================================================ FILE: docs/templates/python/material/labels.html ================================================ ================================================ FILE: mkdocs.yml ================================================ site_name: invoke-training site_url: https://invoke-ai.github.io/invoke-training/ repo_name: invoke-ai/invoke-training repo_url: https://github.com/invoke-ai/invoke-training theme: name: material features: - navigation.tabs - navigation.indexes - navigation.sections - content.code.copy markdown_extensions: - admonition - sane_lists - pymdownx.highlight: anchor_linenums: true line_spans: __span pygments_lang_class: true - pymdownx.inlinehilite - pymdownx.snippets - pymdownx.superfences nav: - Welcome: index.md - Get Started: - get-started/installation.md - 
get-started/quick-start.md - Guides: - Dataset Formats: guides/dataset_formats.md - Model Merging: guides/model_merge.md - Stable Diffusion Training: - guides/stable_diffusion/robocats_finetune_sdxl.md - guides/stable_diffusion/gnome_lora_masks_sdxl.md - guides/stable_diffusion/textual_inversion_sdxl.md - guides/stable_diffusion/dpo_lora_sd.md - YAML Config Reference: - reference/config/index.md - pipelines: - SD LoRA Config: reference/config/pipelines/sd_lora.md - SD Textual Inversion Config: reference/config/pipelines/sd_textual_inversion.md - SDXL LoRA Config: reference/config/pipelines/sdxl_lora.md - SDXL Textual Inversion Config: reference/config/pipelines/sdxl_textual_inversion.md - SDXL LoRA and Textual Inversion Config: reference/config/pipelines/sdxl_lora_and_textual_inversion.md - SDXL Finetune Config: reference/config/pipelines/sdxl_finetune.md - shared: - data_loader_config: reference/config/shared/data/data_loader_config.md - dataset_config: reference/config/shared/data/dataset_config.md - optimizer_config: reference/config/shared/optimizer_config.md - Contributing: - contributing/development_environment.md - contributing/directory_structure.md - contributing/tests.md - contributing/documentation.md plugins: - search - mkdocstrings: default_handler: python custom_templates: docs/templates handlers: python: options: show_root_heading: false show_root_toc_entry: false show_bases: false show_source: false show_if_no_docstring: true inherited_members: true annotations_path: brief separate_signature: true show_signature_annotations: true members_order: source ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["setuptools>=65.5", "pip>=22.3"] build-backend = "setuptools.build_meta" [project] name = "invoke-training" version = "0.0.1" authors = [{ name = "The Invoke AI Team", email = "ryan@invoke.ai" }] description = "A library for Stable Diffusion model training." 
readme = "README.md" requires-python = ">=3.10" license = { text = "Apache-2.0" } classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent", ] dependencies = [ "accelerate", "datasets~=2.14.3", "diffusers[torch]", "einops", "fastapi", "gradio", "invokeai>=5.10.0a1", "numpy<2.0.0", "omegaconf", "peft~=0.11.1", "pillow", "prodigyopt", "pydantic", "pyyaml", "safetensors", "tensorboard", "torch", "torchvision", "tqdm", "transformers", "uvicorn[standard]", ] [project.optional-dependencies] "xformers" = ["xformers>=0.0.28.post1; sys_platform!='darwin'"] "bitsandbytes" = ["bitsandbytes>=0.43.1; sys_platform!='darwin'"] "test" = [ "mkdocs", "mkdocs-material", "mkdocstrings[python]", "pre-commit~=3.3.3", "pytest~=7.4.0", "ruff~=0.11.2", "ruff-lsp", ] [project.scripts] "invoke-train" = "invoke_training.scripts.invoke_train:main" "invoke-train-ui" = "invoke_training.scripts.invoke_train_ui:main" "invoke-generate-images" = "invoke_training.scripts.invoke_generate_images:main" "invoke-visualize-data-loading" = "invoke_training.scripts.invoke_visualize_data_loading:main" [project.urls] "Homepage" = "https://github.com/invoke-ai/invoke-training" "Discord" = "https://discord.gg/ZmtBAhwWhy" [tool.setuptools.package-data] "invoke_training.assets" = ["*.png"] "invoke_training.sample_configs" = ["**/*.yaml"] "invoke_training.ui" = ["*.html"] [tool.ruff] src = ["src"] lint.select = ["E", "F", "W", "C9", "N8", "I"] target-version = "py39" line-length = 120 [tool.pytest.ini_options] addopts = "--strict-markers" markers = [ "cuda: marks tests that require a CUDA GPU", "loads_model: marks tests that require a model (or data) from the HF hub", ] ================================================ FILE: sample_data/bruce_the_gnome/data.jsonl ================================================ {"image": "001.png", "text": "A stuffed gnome sits on a wooden floor, facing right with a gray couch in the background."} {"image": "002.png", "text": "A stuffed gnome stands 
on a black tiled floor, with a silver refrigerator and white wall in the background."} {"image": "004.png", "text": "A stuffed gnome sits on a white marble floor, photorealistic."} {"image": "003.png", "text": "A stuffed gnome sits on a gray tiled floor, facing the camera."} ================================================ FILE: src/invoke_training/__init__.py ================================================ ================================================ FILE: src/invoke_training/_shared/__init__.py ================================================ ================================================ FILE: src/invoke_training/_shared/accelerator/__init__.py ================================================ ================================================ FILE: src/invoke_training/_shared/accelerator/accelerator_utils.py ================================================ import logging import os from typing import Literal import datasets import diffusers import torch import transformers from accelerate import Accelerator from accelerate.logging import MultiProcessAdapter, get_logger from accelerate.utils import ProjectConfiguration def initialize_accelerator( out_dir: str, gradient_accumulation_steps: int, mixed_precision: str, log_with: str ) -> Accelerator: """Configure Hugging Face accelerate and return an Accelerator. Args: out_dir (str): The output directory where results will be written. gradient_accumulation_steps (int): Forwarded to accelerat.Accelerator(...). mixed_precision (str): Forwarded to accelerate.Accelerator(...). log_with (str): Forwarded to accelerat.Accelerator(...) 
Returns: Accelerator """ accelerator_project_config = ProjectConfiguration( project_dir=out_dir, logging_dir=os.path.join(out_dir, "logs"), ) return Accelerator( project_config=accelerator_project_config, gradient_accumulation_steps=gradient_accumulation_steps, mixed_precision=mixed_precision, log_with=log_with, ) def initialize_logging(logger_name: str, accelerator: Accelerator) -> MultiProcessAdapter: """Configure logging. Returns an accelerate logger with multi-process logging support. Logging is configured to be more verbose on the main process. Non-main processes only log at error level for Hugging Face libraries (datasets, transformers, diffusers). Args: accelerator (Accelerator): The Accelerator to configure. Returns: MultiProcessAdapter: _description_ """ logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) if accelerator.is_local_main_process: datasets.utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_warning() diffusers.utils.logging.set_verbosity_info() else: # Only log errors from non-main processes. datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() return get_logger(logger_name) def get_mixed_precision_dtype(accelerator: Accelerator): """Extract torch.dtype from Accelerator config. Args: accelerator (Accelerator): The Hugging Face Accelerator. Raises: NotImplementedError: If the accelerator's mixed_precision configuration is not recognized. Returns: torch.dtype: The weight type inferred from the accelerator mixed_precision configuration. 
""" weight_dtype: torch.dtype = torch.float32 if accelerator.mixed_precision is None or accelerator.mixed_precision == "no": weight_dtype = torch.float32 elif accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 else: raise NotImplementedError(f"mixed_precision mode '{accelerator.mixed_precision}' is not yet supported.") return weight_dtype def get_dtype_from_str(dtype_str: Literal["float16", "bfloat16", "float32"]) -> torch.dtype: if dtype_str == "float16": return torch.float16 elif dtype_str == "bfloat16": return torch.bfloat16 elif dtype_str == "float32": return torch.float32 else: raise ValueError(f"Unsupported dtype: {dtype_str}") ================================================ FILE: src/invoke_training/_shared/checkpoints/__init__.py ================================================ ================================================ FILE: src/invoke_training/_shared/checkpoints/checkpoint_tracker.py ================================================ import os import shutil import typing class CheckpointTracker: """A utility class for managing checkpoint paths. Manages checkpoint paths of the following forms: - Checkpoint directories: `{base_dir}/{prefix}-epoch_{num_epochs}-step_{num_steps}` - Checkpoint files: `{base_dir}/{prefix}-epoch_{num_epochs}-step_{num_steps}{extension}` """ def __init__( self, base_dir: str, prefix: str, extension: typing.Optional[str] = None, max_checkpoints: typing.Optional[int] = None, index_padding: int = 8, ): """Initialize a CheckpointTracker. Args: base_dir (str): The base checkpoint directory. prefix (str): A prefix applied to every checkpoint. extension (str, optional): If set, this is the file extension that will be applied to all checkpoints (usually one of ".pt", ".ckpt", or ".safetensors"). If None, then it will be assumed that we are managing checkpoint directories rather than files. 
max_checkpoints (typing.Optional[int], optional): The maximum number of checkpoints that should exist in base_dir. index_padding (int, optional): The length of the zero-padded epoch/step counts in the generated checkpoint names. E.g. index_padding=8 would produce checkpoint paths like "base_dir/prefix-epoch_00000001-step_00000001.ckpt". Raises: ValueError: If extension is provided, but it doesn not start with a '.'. """ if extension is not None and not extension.startswith("."): raise ValueError(f"extension='{extension}' must start with a '.'.") self._base_dir = base_dir self._prefix = prefix self._extension = extension self._max_checkpoints = max_checkpoints self._index_padding = index_padding def prune(self, buffer_num: int = 1) -> int: """Delete checkpoint files and directories so that there are at most `max_checkpoints - buffer_num` checkpoints remaining. The checkpoints with the lowest step counts will be deleted. Args: buffer_num (int, optional): The number below `max_checkpoints` to 'free-up'. Returns: int: The number of checkpoints deleted. """ if self._max_checkpoints is None: return 0 checkpoints = os.listdir(self._base_dir) checkpoints = [p for p in checkpoints if p.startswith(self._prefix)] checkpoints = sorted( checkpoints, key=lambda x: int(os.path.splitext(x)[0].split("-step_")[-1]), ) num_to_remove = len(checkpoints) - (self._max_checkpoints - buffer_num) if num_to_remove > 0: checkpoints_to_remove = checkpoints[:num_to_remove] for checkpoint_to_remove in checkpoints_to_remove: checkpoint_to_remove = os.path.join(self._base_dir, checkpoint_to_remove) if os.path.isfile(checkpoint_to_remove): # Delete checkpoint file. os.remove(checkpoint_to_remove) else: # Delete checkpoint directory. shutil.rmtree(checkpoint_to_remove) return max(0, num_to_remove) def get_path(self, epoch: int, step: int) -> str: """Get the checkpoint path for index `idx`. Args: epoch (int): The number of completed epochs. step (int): The number of completed training steps. 
Returns: str: The checkpoint path. """ suffix = self._extension or "" return os.path.join( self._base_dir, f"{self._prefix.strip()}-epoch_{epoch:0>{self._index_padding}}-step_{step:0>{self._index_padding}}{suffix}", ) ================================================ FILE: src/invoke_training/_shared/checkpoints/lora_checkpoint_utils.py ================================================ from pathlib import Path import peft import torch def save_multi_model_peft_checkpoint(checkpoint_dir: Path | str, models: dict[str, peft.PeftModel]): """Save a dict of PeftModels to a checkpoint directory. The `models` dict keys are used as the subdirectories for each individual model. `load_multi_model_peft_checkpoint(...)` can be used to load the resultant checkpoint. """ checkpoint_dir = Path(checkpoint_dir) for model_key, peft_model in models.items(): assert isinstance(peft_model, peft.PeftModel) # HACK(ryand): PeftModel.save_pretrained(...) expects the config to have a "_name_or_path" entry. For now, we # set this to None here. This should be fixed upstream in PEFT. 
if ( hasattr(peft_model, "config") and isinstance(peft_model.config, dict) and "_name_or_path" not in peft_model.config ): peft_model.config["_name_or_path"] = None peft_model.save_pretrained(str(checkpoint_dir / model_key)) def load_multi_model_peft_checkpoint( checkpoint_dir: Path | str, models: dict[str, torch.nn.Module], is_trainable: bool = False, raise_if_subdir_missing: bool = True, ) -> dict[str, torch.nn.Module]: """Load a multi-model PEFT checkpoint that was saved with `save_multi_model_peft_checkpoint(...)`.""" checkpoint_dir = Path(checkpoint_dir) assert checkpoint_dir.exists() out_models = {} for model_key, model in models.items(): dir_path: Path = checkpoint_dir / model_key if dir_path.exists(): out_models[model_key] = peft.PeftModel.from_pretrained(model, dir_path, is_trainable=is_trainable) else: if raise_if_subdir_missing: raise ValueError(f"'{dir_path}' does not exist.") else: # Pass through the model unchanged. out_models[model_key] = model return out_models # This implementation is based on # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/lora_dreambooth/convert_peft_sd_lora_to_kohya_ss.py#L20 def _convert_peft_state_dict_to_kohya_state_dict( lora_config: peft.LoraConfig, peft_state_dict: dict[str, torch.Tensor], prefix: str, dtype: torch.dtype, ) -> dict[str, torch.Tensor]: kohya_ss_state_dict = {} for peft_key, weight in peft_state_dict.items(): kohya_key = peft_key.replace("base_model.model", prefix) kohya_key = kohya_key.replace("lora_A", "lora_down") kohya_key = kohya_key.replace("lora_B", "lora_up") kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2) kohya_ss_state_dict[kohya_key] = weight.to(dtype) # Set alpha parameter if "lora_down" in kohya_key: alpha_key = f"{kohya_key.split('.')[0]}.alpha" kohya_ss_state_dict[alpha_key] = torch.tensor(lora_config.lora_alpha).to(dtype) return kohya_ss_state_dict def _convert_peft_models_to_kohya_state_dict( kohya_prefixes: list[str], models: 
list[peft.PeftModel] ) -> dict[str, torch.Tensor]: kohya_state_dict = {} default_adapter_name = "default" for kohya_prefix, peft_model in zip(kohya_prefixes, models, strict=True): lora_config = peft_model.peft_config[default_adapter_name] assert isinstance(lora_config, peft.LoraConfig) peft_state_dict = peft.get_peft_model_state_dict(peft_model, adapter_name=default_adapter_name) kohya_state_dict.update( _convert_peft_state_dict_to_kohya_state_dict( lora_config=lora_config, peft_state_dict=peft_state_dict, prefix=kohya_prefix, dtype=torch.float32, ) ) return kohya_state_dict ================================================ FILE: src/invoke_training/_shared/checkpoints/serialization.py ================================================ import typing from pathlib import Path import safetensors.torch import torch def save_state_dict(state_dict: typing.Dict[str, torch.Tensor], out_file: typing.Union[Path, str]): """Save a state_dict to a file. Both safetensors and torch formats are supported. The format is inferred from the `out_file` extension. Supported extensions: - ".ckpt" -> torch - ".pt" -> torch - ".safetensors -> safetensors Args: state_dict (typing.Dict[str, torch.Tensor]): The state_dict to save. out_file (Path | str): The output file to save to. Raises: ValueError: If the `out_file` has an unsupported file extension. """ out_file = Path(out_file) if out_file.suffix == ".ckpt" or out_file.suffix == ".pt": torch.save(state_dict, out_file) elif out_file.suffix == ".safetensors": safetensors.torch.save_file(state_dict, out_file) else: raise ValueError(f"Unsupported file extension: '{out_file.suffix}'.") def load_state_dict(in_file: typing.Union[Path, str]) -> typing.Dict[str, torch.Tensor]: """Load a state_dict from a file. Both safetensors and torch formats are supported. The format is inferred from the `in_file` extension. 
Supported extensions: - ".ckpt" -> torch - ".pt" -> torch - ".safetensors -> safetensors Args: in_file (Path | str): The input file to load from. Raises: ValueError: If the `in_file` has an unsupported file extension. Returns: typing.Dict[str, torch.Tensor]: The loaded state_dict. """ in_file = Path(in_file) if in_file.suffix == ".ckpt" or in_file.suffix == ".pt": return torch.load(in_file) elif in_file.suffix == ".safetensors": return safetensors.torch.load_file(in_file) else: raise ValueError(f"Unsupported file extension: '{in_file.suffix}'.") ================================================ FILE: src/invoke_training/_shared/data/ARCHITECTURE.md ================================================ # Dataset Architecture Dataset handling is split into 3 layers of abstraction: Datasets, Transforms, and DataLoaders. Each is explained in more detail below. ## Datasets Datasets implement the [torch.utils.data.Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files) interface. Most dataset classes act as an abstraction over a specific dataset format. ## Transforms Transforms are functions applied to data loaded by Datasets. For example, the `SDImageTransform` implements image augmentations for Stable Diffusion training. Transforms are kept separate from the underlying datasets for several reasons: - It is easier to write tests for isolated transforms. - Modular transforms can often be re-used for multiple base datasets. - Modular transforms make it easy to customize datasets for different situations. For example, you may want to wrap a dataset with one set of transforms initially to populate a cache, and then with a different set of transforms to read from the cache. Transforms are applied to a dataset via the `TransformDataset` class. ## DataLoaders The dataset classes (with composed transforms) are wrapped in a `torch.utils.data.DataLoader` that handles batch collation, multi-processing, etc. 
================================================
FILE: src/invoke_training/_shared/data/__init__.py
================================================

================================================
FILE: src/invoke_training/_shared/data/data_loaders/__init__.py
================================================

================================================
FILE: src/invoke_training/_shared/data/data_loaders/dreambooth_sd_dataloader.py
================================================
import typing

from torch.utils.data import ConcatDataset, DataLoader
from torch.utils.data.sampler import RandomSampler, SequentialSampler

from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
    build_aspect_ratio_bucket_manager,
    sd_image_caption_collate_fn,
)
from invoke_training._shared.data.datasets.image_dir_dataset import ImageDirDataset
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler
from invoke_training._shared.data.samplers.batch_offset_sampler import BatchOffsetSampler
from invoke_training._shared.data.samplers.concat_sampler import ConcatSampler
from invoke_training._shared.data.samplers.interleaved_sampler import InterleavedSampler
from invoke_training._shared.data.samplers.offset_sampler import OffsetSampler
from invoke_training._shared.data.transforms.constant_field_transform import ConstantFieldTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig


def build_dreambooth_sd_dataloader(
    config: DreamboothSDDataLoaderConfig,
    batch_size: int,
    text_encoder_output_cache_dir: typing.Optional[str] = None,
    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,
    vae_output_cache_dir: typing.Optional[str] = None,
    shuffle: bool = True,
    sequential_batching: bool = False,
) -> DataLoader:
    """Construct a DataLoader for a DreamBooth dataset for Stable Diffusion training.

    Examples are drawn from an instance image directory and, optionally, a class image directory. Every instance
    example gets `config.instance_caption` and a loss weight of 1.0. Every class example gets `config.class_caption`
    and a loss weight of `config.class_data_loss_weight`.

    Args:
        config (DreamboothSDDataLoaderConfig): The data loader config.
        batch_size (int): The DataLoader batch size.
        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should
            be loaded from.
        text_encoder_cache_field_to_output_field (dict[str, str], optional): Mapping from text encoder cache field
            names to output example field names. Must be set when `text_encoder_output_cache_dir` is set.
        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If
            set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM.
        shuffle (bool, optional): Whether to shuffle the dataset order.
        sequential_batching (bool, optional): If True, the internal dataset will be processed sequentially rather than
            interleaving class and instance examples. This is intended to be used when processing the entire dataset
            for caching purposes. Defaults to False.

    Returns:
        DataLoader
    """
    # Prepare instance dataset.
    base_instance_dataset = ImageDirDataset(
        config.instance_dataset.dataset_dir,
        id_prefix="instance_",
        keep_in_memory=config.instance_dataset.keep_in_memory,
    )
    instance_dataset = TransformDataset(
        base_instance_dataset,
        [
            ConstantFieldTransform("caption", config.instance_caption),
            ConstantFieldTransform("loss_weight", 1.0),
        ],
    )
    datasets = [instance_dataset]

    # Prepare class dataset.
    base_class_dataset = None
    class_dataset = None
    if config.class_dataset is not None:
        base_class_dataset = ImageDirDataset(
            config.class_dataset.dataset_dir, id_prefix="class_", keep_in_memory=config.class_dataset.keep_in_memory
        )
        class_dataset = TransformDataset(
            base_class_dataset,
            [
                ConstantFieldTransform("caption", config.class_caption),
                ConstantFieldTransform("loss_weight", config.class_data_loss_weight),
            ],
        )
        datasets.append(class_dataset)

    # Merge instance dataset and class dataset. Instance examples occupy the first len(base_instance_dataset)
    # indices of the merged dataset; class examples follow.
    merged_dataset = ConcatDataset(datasets)

    # Initialize either the fixed target resolution or aspect ratio buckets.
    target_resolution = None
    aspect_ratio_bucket_manager = None
    instance_sampler = None
    class_sampler = None
    if config.aspect_ratio_buckets is None:
        target_resolution = config.resolution
        # TODO(ryand): Provide a seeded generator.
        instance_sampler = RandomSampler(instance_dataset) if shuffle else SequentialSampler(instance_dataset)
        if base_class_dataset is not None:
            class_sampler = RandomSampler(class_dataset) if shuffle else SequentialSampler(class_dataset)
            # Shift class indices past the instance examples so they address the class part of merged_dataset.
            class_sampler = OffsetSampler(class_sampler, offset=len(base_instance_dataset))
    else:
        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)
        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.
        instance_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
            bucket_manager=aspect_ratio_bucket_manager,
            image_sizes=base_instance_dataset.get_image_dimensions(),
            batch_size=batch_size,
            shuffle=shuffle,
            seed=0,
        )
        if base_class_dataset is not None:
            class_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
                bucket_manager=aspect_ratio_bucket_manager,
                image_sizes=base_class_dataset.get_image_dimensions(),
                batch_size=batch_size,
                shuffle=shuffle,
                seed=0,
            )
            # Same index shift as above, applied to every index in each yielded batch.
            class_sampler = BatchOffsetSampler(class_sampler, offset=len(base_instance_dataset))

    # Add transforms to the merged dataset.
    all_transforms = []
    if vae_output_cache_dir is None:
        all_transforms.append(
            SDImageTransform(
                image_field_names=["image"],
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=target_resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )
    else:
        # Cached VAE outputs (looked up by example id) replace the image transform entirely.
        vae_cache = TensorDiskCache(vae_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache,
                cache_key_field="id",
                cache_field_to_output_field={
                    "vae_output": "vae_output",
                    "original_size_hw": "original_size_hw",
                    "crop_top_left_yx": "crop_top_left_yx",
                },
            )
        )
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))

    if text_encoder_output_cache_dir is not None:
        # The field mapping must accompany the cache dir.
        assert text_encoder_cache_field_to_output_field is not None
        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=text_encoder_cache,
                cache_key_field="id",
                cache_field_to_output_field=text_encoder_cache_field_to_output_field,
            )
        )

    merged_dataset = TransformDataset(merged_dataset, all_transforms)

    # Choose between sequential vs. interleaved merging of the instance and class samplers.
    # Sequential sampling is typically used to populate a cache, because it guarantees that all examples will be
    # included in an epoch.
    samplers = [instance_sampler]
    if class_sampler is not None:
        samplers.append(class_sampler)

    if sequential_batching:
        sampler = ConcatSampler(samplers)
    else:
        sampler = InterleavedSampler(samplers)

    if config.aspect_ratio_buckets is None:
        return DataLoader(
            merged_dataset,
            sampler=sampler,
            collate_fn=sd_image_caption_collate_fn,
            batch_size=batch_size,
            num_workers=config.dataloader_num_workers,
        )
    else:
        # If config.aspect_ratio_buckets is not None, then we are using a batch sampler. The samplers already yield
        # whole batches, so it is passed as `batch_sampler` and `batch_size` is not forwarded.
        return DataLoader(
            merged_dataset,
            batch_sampler=sampler,
            collate_fn=sd_image_caption_collate_fn,
            num_workers=config.dataloader_num_workers,
        )

================================================
FILE: src/invoke_training/_shared/data/data_loaders/image_caption_flux_dataloader.py
================================================
import typing

from torch.utils.data import DataLoader

from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
    build_aspect_ratio_bucket_manager,
)
from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import (
    sd_image_caption_collate_fn as flux_image_caption_collate_fn,
)
from invoke_training._shared.data.datasets.build_dataset import (
    build_hf_hub_image_caption_dataset,
    build_image_caption_dir_dataset,
    build_image_caption_jsonl_dataset,
)
from invoke_training._shared.data.datasets.transform_dataset import TransformDataset
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import (
    AspectRatioBucketBatchSampler,
)
from invoke_training._shared.data.transforms.caption_prefix_transform import CaptionPrefixTransform
from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform
from invoke_training._shared.data.transforms.flux_image_transform import FluxImageTransform
from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training.config.data.data_loader_config import ImageCaptionFluxDataLoaderConfig
from invoke_training.config.data.dataset_config import (
    HFHubImageCaptionDatasetConfig,
    ImageCaptionDirDatasetConfig,
    ImageCaptionJsonlDatasetConfig,
)


def build_image_caption_flux_dataloader(  # noqa: C901
    config: ImageCaptionFluxDataLoaderConfig,
    batch_size: int,
    use_masks: bool = False,
    text_encoder_output_cache_dir: typing.Optional[str] = None,
    text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None,
    vae_output_cache_dir: typing.Optional[str] = None,
    shuffle: bool = True,
) -> DataLoader:
    """Construct a DataLoader for an image-caption dataset for Flux.1-dev.

    Args:
        config (ImageCaptionFluxDataLoaderConfig): The dataset config.
        batch_size (int): The DataLoader batch size.
        use_masks (bool, optional): If True, the "mask" field is kept: it is transformed alongside the image, or loaded
            from the VAE cache when `vae_output_cache_dir` is set. If False, the "mask" field is dropped.
        text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should
            be loaded from.
        text_encoder_cache_field_to_output_field (dict[str, str], optional): Mapping from text encoder cache field
            names to output example field names. Must be set when `text_encoder_output_cache_dir` is set.
        vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If
            set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM.
        shuffle (bool, optional): Whether to shuffle the dataset order.

    Raises:
        ValueError: If `config.dataset` is not one of the supported dataset config types.

    Returns:
        DataLoader
    """
    if isinstance(config.dataset, HFHubImageCaptionDatasetConfig):
        base_dataset = build_hf_hub_image_caption_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionJsonlDatasetConfig):
        base_dataset = build_image_caption_jsonl_dataset(config.dataset)
    elif isinstance(config.dataset, ImageCaptionDirDatasetConfig):
        base_dataset = build_image_caption_dir_dataset(config.dataset)
    else:
        raise ValueError(f"Unexpected dataset config type: '{type(config.dataset)}'.")

    # Initialize either the fixed target resolution or aspect ratio buckets.
    if config.aspect_ratio_buckets is None:
        aspect_ratio_bucket_manager = None
        batch_sampler = None
    else:
        aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets)
        # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here.
        batch_sampler = AspectRatioBucketBatchSampler.from_image_sizes(
            bucket_manager=aspect_ratio_bucket_manager,
            image_sizes=base_dataset.get_image_dimensions(),
            batch_size=batch_size,
            shuffle=shuffle,
            seed=0,
        )

    all_transforms = []
    if config.caption_prefix is not None:
        all_transforms.append(CaptionPrefixTransform(caption_field_name="caption", prefix=config.caption_prefix + " "))

    if vae_output_cache_dir is None:
        image_field_names = ["image"]
        if use_masks:
            # The mask gets the same spatial transform as the image (but is not normalized to [-1, 1]).
            image_field_names.append("mask")
        else:
            all_transforms.append(DropFieldTransform("mask"))

        all_transforms.append(
            FluxImageTransform(
                image_field_names=image_field_names,
                fields_to_normalize_to_range_minus_one_to_one=["image"],
                resolution=config.resolution,
                aspect_ratio_bucket_manager=aspect_ratio_bucket_manager,
                center_crop=config.center_crop,
                random_flip=config.random_flip,
            )
        )
    else:
        # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation.
        all_transforms.append(DropFieldTransform("image"))
        # The raw mask is dropped too; when use_masks is True, the cached mask is loaded below instead.
        all_transforms.append(DropFieldTransform("mask"))

        vae_cache = TensorDiskCache(vae_output_cache_dir)

        cache_field_to_output_field = {
            "vae_output": "vae_output",
            "original_size_hw": "original_size_hw",
            "crop_top_left_yx": "crop_top_left_yx",
        }
        if use_masks:
            cache_field_to_output_field["mask"] = "mask"
        all_transforms.append(
            LoadCacheTransform(
                cache=vae_cache,
                cache_key_field="id",
                cache_field_to_output_field=cache_field_to_output_field,
            )
        )

    if text_encoder_output_cache_dir is not None:
        # The field mapping must accompany the cache dir.
        assert text_encoder_cache_field_to_output_field is not None
        text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir)
        all_transforms.append(
            LoadCacheTransform(
                cache=text_encoder_cache,
                cache_key_field="id",
                cache_field_to_output_field=text_encoder_cache_field_to_output_field,
            )
        )

    dataset = TransformDataset(base_dataset, all_transforms)

    if batch_sampler is None:
        return DataLoader(
            dataset,
            shuffle=shuffle,
            collate_fn=flux_image_caption_collate_fn,
            batch_size=batch_size,
            num_workers=config.dataloader_num_workers,
        )
    else:
        # The bucket sampler yields whole batches, so `shuffle` and `batch_size` are not forwarded here.
        return DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            collate_fn=flux_image_caption_collate_fn,
            num_workers=config.dataloader_num_workers,
        )

================================================ FILE: src/invoke_training/_shared/data/data_loaders/image_caption_sd_dataloader.py ================================================ import typing import torch from torch.utils.data import DataLoader from invoke_training._shared.data.datasets.build_dataset import ( build_hf_hub_image_caption_dataset, build_image_caption_dir_dataset, build_image_caption_jsonl_dataset, ) from invoke_training._shared.data.datasets.transform_dataset import TransformDataset from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler from invoke_training._shared.data.transforms.caption_prefix_transform import CaptionPrefixTransform from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager from invoke_training.config.data.data_loader_config import AspectRatioBucketConfig, ImageCaptionSDDataLoaderConfig from invoke_training.config.data.dataset_config import ( HFHubImageCaptionDatasetConfig, ImageCaptionDirDatasetConfig, ImageCaptionJsonlDatasetConfig, ) def sd_image_caption_collate_fn(examples): """A batch collation function for the image-caption SDXL data loader.""" out_examples = { "id": [example["id"] for example in examples], } if "image" in examples[0]: out_examples["image"] = torch.stack([example["image"] for example in examples]) if "original_size_hw" in examples[0]: out_examples["original_size_hw"]
= [example["original_size_hw"] for example in examples] if "crop_top_left_yx" in examples[0]: out_examples["crop_top_left_yx"] = [example["crop_top_left_yx"] for example in examples] if "caption" in examples[0]: out_examples["caption"] = [example["caption"] for example in examples] if "loss_weight" in examples[0]: out_examples["loss_weight"] = torch.tensor([example["loss_weight"] for example in examples]) if "prompt_embeds" in examples[0]: out_examples["prompt_embeds"] = torch.stack([example["prompt_embeds"] for example in examples]) out_examples["pooled_prompt_embeds"] = torch.stack([example["pooled_prompt_embeds"] for example in examples]) if "text_encoder_output" in examples[0]: out_examples["text_encoder_output"] = torch.stack([example["text_encoder_output"] for example in examples]) if "vae_output" in examples[0]: out_examples["vae_output"] = torch.stack([example["vae_output"] for example in examples]) if "mask" in examples[0]: out_examples["mask"] = torch.stack([example["mask"] for example in examples]) return out_examples def build_aspect_ratio_bucket_manager(config: AspectRatioBucketConfig): return AspectRatioBucketManager.from_constraints( target_resolution=config.target_resolution, start_dim=config.start_dim, end_dim=config.end_dim, divisible_by=config.divisible_by, ) def build_image_caption_sd_dataloader( # noqa: C901 config: ImageCaptionSDDataLoaderConfig, batch_size: int, use_masks: bool = False, text_encoder_output_cache_dir: typing.Optional[str] = None, text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None, vae_output_cache_dir: typing.Optional[str] = None, shuffle: bool = True, ) -> DataLoader: """Construct a DataLoader for an image-caption dataset for Stable Diffusion XL. Args: config (ImageCaptionSDDataLoaderConfig): The dataset config. batch_size (int): The DataLoader batch size. text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should be loaded from. 
If set, then the TokenizeTransform will not be applied. vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM. shuffle (bool, optional): Whether to shuffle the dataset order. Returns: DataLoader """ if isinstance(config.dataset, HFHubImageCaptionDatasetConfig): base_dataset = build_hf_hub_image_caption_dataset(config.dataset) elif isinstance(config.dataset, ImageCaptionJsonlDatasetConfig): base_dataset = build_image_caption_jsonl_dataset(config.dataset) elif isinstance(config.dataset, ImageCaptionDirDatasetConfig): base_dataset = build_image_caption_dir_dataset(config.dataset) else: raise ValueError(f"Unexpected dataset config type: '{type(config.dataset)}'.") # Initialize either the fixed target resolution or aspect ratio buckets. if config.aspect_ratio_buckets is None: target_resolution = config.resolution aspect_ratio_bucket_manager = None batch_sampler = None else: target_resolution = None aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets) # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here. 
batch_sampler = AspectRatioBucketBatchSampler.from_image_sizes( bucket_manager=aspect_ratio_bucket_manager, image_sizes=base_dataset.get_image_dimensions(), batch_size=batch_size, shuffle=shuffle, seed=0, ) all_transforms = [] if config.caption_prefix is not None: all_transforms.append(CaptionPrefixTransform(caption_field_name="caption", prefix=config.caption_prefix + " ")) if vae_output_cache_dir is None: image_field_names = ["image"] if use_masks: image_field_names.append("mask") else: all_transforms.append(DropFieldTransform("mask")) all_transforms.append( SDImageTransform( image_field_names=image_field_names, fields_to_normalize_to_range_minus_one_to_one=["image"], resolution=target_resolution, aspect_ratio_bucket_manager=aspect_ratio_bucket_manager, center_crop=config.center_crop, random_flip=config.random_flip, ) ) else: # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation. all_transforms.append(DropFieldTransform("image")) all_transforms.append(DropFieldTransform("mask")) vae_cache = TensorDiskCache(vae_output_cache_dir) cache_field_to_output_field = { "vae_output": "vae_output", "original_size_hw": "original_size_hw", "crop_top_left_yx": "crop_top_left_yx", } if use_masks: cache_field_to_output_field["mask"] = "mask" all_transforms.append( LoadCacheTransform( cache=vae_cache, cache_key_field="id", cache_field_to_output_field=cache_field_to_output_field, ) ) if text_encoder_output_cache_dir is not None: assert text_encoder_cache_field_to_output_field is not None text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir) all_transforms.append( LoadCacheTransform( cache=text_encoder_cache, cache_key_field="id", cache_field_to_output_field=text_encoder_cache_field_to_output_field, ) ) dataset = TransformDataset(base_dataset, all_transforms) if batch_sampler is None: return DataLoader( dataset, shuffle=shuffle, collate_fn=sd_image_caption_collate_fn, batch_size=batch_size, 
num_workers=config.dataloader_num_workers, ) else: return DataLoader( dataset, batch_sampler=batch_sampler, collate_fn=sd_image_caption_collate_fn, num_workers=config.dataloader_num_workers, ) ================================================ FILE: src/invoke_training/_shared/data/data_loaders/image_pair_preference_sd_dataloader.py ================================================ import typing import torch from torch.utils.data import DataLoader from invoke_training._shared.data.datasets.build_dataset import build_hf_image_pair_preference_dataset from invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset from invoke_training._shared.data.datasets.transform_dataset import TransformDataset from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache from invoke_training.pipelines._experimental.sd_dpo_lora.config import ImagePairPreferenceSDDataLoaderConfig def sd_image_pair_preference_collate_fn(examples): """A batch collation function.""" stack_keys = {"image_0", "image_1", "prompt_embeds", "pooled_prompt_embeds", "text_encoder_output", "vae_output"} list_keys = { "id", "original_size_hw_0", "original_size_hw_1", "crop_top_left_yx_0", "crop_top_left_yx_1", "prefer_0", "prefer_1", "caption", } unhandled_keys = set(examples[0].keys()) - (stack_keys | list_keys) if len(unhandled_keys) > 0: raise ValueError(f"The following keys are not handled by the collate function: {unhandled_keys}.") out_examples = {} # torch.stack(...) for k in stack_keys: if k in examples[0]: out_examples[k] = torch.stack([example[k] for example in examples]) # Basic list. 
for k in list_keys: if k in examples[0]: out_examples[k] = [example[k] for example in examples] return out_examples def build_image_pair_preference_sd_dataloader( config: ImagePairPreferenceSDDataLoaderConfig, batch_size: int, text_encoder_output_cache_dir: typing.Optional[str] = None, text_encoder_cache_field_to_output_field: typing.Optional[dict[str, str]] = None, vae_output_cache_dir: typing.Optional[str] = None, shuffle: bool = True, ) -> DataLoader: """Construct a DataLoader for an image-caption dataset for Stable Diffusion XL. Args: config (ImageCaptionSDDataLoaderConfig): The dataset config. batch_size (int): The DataLoader batch size. text_encoder_output_cache_dir (str, optional): The directory where text encoder outputs are cached and should be loaded from. If set, then the TokenizeTransform will not be applied. vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM. shuffle (bool, optional): Whether to shuffle the dataset order. Returns: DataLoader """ if config.dataset.type == "HF_HUB_IMAGE_PAIR_PREFERENCE_DATASET": base_dataset = build_hf_image_pair_preference_dataset(config=config.dataset) elif config.dataset.type == "IMAGE_PAIR_PREFERENCE_DATASET": base_dataset = ImagePairPreferenceDataset(dataset_dir=config.dataset.dataset_dir) else: raise ValueError(f"Unexpected dataset config type: '{type(config.dataset)}'.") target_resolution = config.resolution all_transforms = [] if vae_output_cache_dir is None: # TODO(ryand): Should I process both images in a single SDImageTransform so that they undergo the same # transformations? 
all_transforms.append( SDImageTransform( image_field_names=["image_0"], fields_to_normalize_to_range_minus_one_to_one=["image_0"], resolution=target_resolution, aspect_ratio_bucket_manager=None, center_crop=config.center_crop, random_flip=config.random_flip, orig_size_field_name="original_size_hw_0", crop_field_name="crop_top_left_yx_0", ) ) all_transforms.append( SDImageTransform( image_field_names=["image_1"], fields_to_normalize_to_range_minus_one_to_one=["image_1"], resolution=target_resolution, aspect_ratio_bucket_manager=None, center_crop=config.center_crop, random_flip=config.random_flip, orig_size_field_name="original_size_hw_1", crop_field_name="crop_top_left_yx_1", ) ) else: raise NotImplementedError("VAE caching is not yet implemented.") # vae_cache = TensorDiskCache(vae_output_cache_dir) # all_transforms.append( # LoadCacheTransform( # cache=vae_cache, # cache_key_field="id", # cache_field_to_output_field={ # "vae_output": "vae_output", # "original_size_hw": "original_size_hw", # "crop_top_left_yx": "crop_top_left_yx", # }, # ) # ) # # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation. 
# all_transforms.append(DropFieldTransform("image")) if text_encoder_output_cache_dir is not None: assert text_encoder_cache_field_to_output_field is not None text_encoder_cache = TensorDiskCache(text_encoder_output_cache_dir) all_transforms.append( LoadCacheTransform( cache=text_encoder_cache, cache_key_field="id", cache_field_to_output_field=text_encoder_cache_field_to_output_field, ) ) dataset = TransformDataset(base_dataset, all_transforms) return DataLoader( dataset, shuffle=shuffle, collate_fn=sd_image_pair_preference_collate_fn, batch_size=batch_size, num_workers=config.dataloader_num_workers, ) ================================================ FILE: src/invoke_training/_shared/data/data_loaders/textual_inversion_sd_dataloader.py ================================================ from typing import Literal, Optional from torch.utils.data import DataLoader from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import ( build_aspect_ratio_bucket_manager, sd_image_caption_collate_fn, ) from invoke_training._shared.data.datasets.build_dataset import ( build_hf_hub_image_caption_dataset, build_image_caption_dir_dataset, build_image_caption_jsonl_dataset, ) from invoke_training._shared.data.datasets.image_dir_dataset import ImageDirDataset from invoke_training._shared.data.datasets.transform_dataset import TransformDataset from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import AspectRatioBucketBatchSampler from invoke_training._shared.data.transforms.concat_fields_transform import ConcatFieldsTransform from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform from invoke_training._shared.data.transforms.shuffle_caption_transform import ShuffleCaptionTransform from 
invoke_training._shared.data.transforms.template_caption_transform import TemplateCaptionTransform from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache from invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig from invoke_training.config.data.dataset_config import ( HFHubImageCaptionDatasetConfig, ImageCaptionDirDatasetConfig, ImageCaptionJsonlDatasetConfig, ImageDirDatasetConfig, ) def get_preset_ti_caption_templates(preset: Literal["object", "style"]) -> list[str]: if preset == "object": return [ "a photo of a {}", "a rendering of a {}", "a cropped photo of the {}", "the photo of a {}", "a photo of a clean {}", "a photo of a dirty {}", "a dark photo of the {}", "a photo of my {}", "a photo of the cool {}", "a close-up photo of a {}", "a bright photo of the {}", "a cropped photo of a {}", "a photo of the {}", "a good photo of the {}", "a photo of one {}", "a close-up photo of the {}", "a rendition of the {}", "a photo of the clean {}", "a rendition of a {}", "a photo of a nice {}", "a good photo of a {}", "a photo of the nice {}", "a photo of the small {}", "a photo of the weird {}", "a photo of the large {}", "a photo of a cool {}", "a photo of a small {}", ] elif preset == "style": return [ "a painting in the style of {}", "a rendering in the style of {}", "a cropped painting in the style of {}", "the painting in the style of {}", "a clean painting in the style of {}", "a dirty painting in the style of {}", "a dark painting in the style of {}", "a picture in the style of {}", "a cool painting in the style of {}", "a close-up painting in the style of {}", "a bright painting in the style of {}", "a good painting in the style of {}", "a close-up painting in the style of {}", "a rendition in the style of {}", "a nice painting in the style of {}", "a small painting in the style of {}", "a weird painting in the style of {}", "a large painting in the style of {}", "a photo in the style of {}", "an 
image in the style of {}", "a drawing in the style of {}", "a sketch in the style of {}", "a digital work in the style of {}", "a digital rendering in the style of {}", "a photograph in the style of {}", "photography in the style of {}", ] else: raise ValueError(f"Unrecognized learnable property type: '{preset}'.") def build_textual_inversion_sd_dataloader( # noqa: C901 config: TextualInversionSDDataLoaderConfig, placeholder_token: str, batch_size: int, use_masks: bool = False, vae_output_cache_dir: Optional[str] = None, shuffle: bool = True, ) -> DataLoader: """Construct a DataLoader for a Textual Inversion dataset for Stable Diffusion. Args: config (TextualInversionSDDataLoaderConfig): The dataset config. placeholder_token (str): The placeholder token being trained. batch_size (int): The DataLoader batch size. vae_output_cache_dir (str, optional): The directory where VAE outputs are cached and should be loaded from. If set, then the image augmentation transforms will be skipped, and the image will not be copied to VRAM. shuffle (bool, optional): Whether to shuffle the dataset order. Returns: DataLoader """ if isinstance(config.dataset, HFHubImageCaptionDatasetConfig): base_dataset = build_hf_hub_image_caption_dataset(config.dataset) elif isinstance(config.dataset, ImageCaptionJsonlDatasetConfig): base_dataset = build_image_caption_jsonl_dataset(config.dataset) elif isinstance(config.dataset, ImageCaptionDirDatasetConfig): base_dataset = build_image_caption_dir_dataset(config.dataset) elif isinstance(config.dataset, ImageDirDatasetConfig): base_dataset = ImageDirDataset( image_dir=config.dataset.dataset_dir, keep_in_memory=config.dataset.keep_in_memory ) else: raise ValueError(f"Unexpected dataset config type: '{type(config.dataset)}'.") # Initialize either the fixed target resolution or aspect ratio buckets. 
if config.aspect_ratio_buckets is None: target_resolution = config.resolution aspect_ratio_bucket_manager = None batch_sampler = None else: target_resolution = None aspect_ratio_bucket_manager = build_aspect_ratio_bucket_manager(config=config.aspect_ratio_buckets) # TODO(ryand): Drill-down the seed parameter rather than hard-coding to 0 here. batch_sampler = AspectRatioBucketBatchSampler.from_image_sizes( bucket_manager=aspect_ratio_bucket_manager, image_sizes=base_dataset.get_image_dimensions(), batch_size=batch_size, shuffle=shuffle, seed=0, ) if sum([config.caption_templates is not None, config.caption_preset is not None]) != 1: raise ValueError("Either caption_templates or caption_preset must be set.") if config.caption_templates is not None: # Overwrites the caption field. Typically used with a ImageDirDataset that does not have captions. caption_tf = TemplateCaptionTransform( field_name="caption_prefix" if config.keep_original_captions else "caption", placeholder_str=placeholder_token, caption_templates=config.caption_templates, ) elif config.caption_preset is not None: # Overwrites the caption field. Typically used with a ImageDirDataset that does not have captions. caption_tf = TemplateCaptionTransform( field_name="caption_prefix" if config.keep_original_captions else "caption", placeholder_str=placeholder_token, caption_templates=get_preset_ti_caption_templates(config.caption_preset), ) else: raise ValueError("Either caption_templates or caption_preset must be set.") all_transforms = [caption_tf] if config.keep_original_captions: # This will only work with a HFHubImageCaptionDataset or HFDirImageCaptionDataset that already has captions. 
all_transforms.append( ConcatFieldsTransform( src_field_names=["caption_prefix", "caption"], dst_field_name="caption", separator=" " ) ) if config.shuffle_caption_delimiter is not None: all_transforms.append(ShuffleCaptionTransform(field_name="caption", delimiter=config.shuffle_caption_delimiter)) if vae_output_cache_dir is None: image_field_names = ["image"] if use_masks: image_field_names.append("mask") else: all_transforms.append(DropFieldTransform("mask")) all_transforms.append( SDImageTransform( image_field_names=image_field_names, fields_to_normalize_to_range_minus_one_to_one=["image"], resolution=target_resolution, aspect_ratio_bucket_manager=aspect_ratio_bucket_manager, center_crop=config.center_crop, random_flip=config.random_flip, ) ) else: # We drop the image to avoid having to either convert from PIL, or handle PIL batch collation. all_transforms.append(DropFieldTransform("image")) all_transforms.append(DropFieldTransform("mask")) vae_cache = TensorDiskCache(vae_output_cache_dir) cache_field_to_output_field = { "vae_output": "vae_output", "original_size_hw": "original_size_hw", "crop_top_left_yx": "crop_top_left_yx", } if use_masks: cache_field_to_output_field["mask"] = "mask" all_transforms.append( LoadCacheTransform( cache=vae_cache, cache_key_field="id", cache_field_to_output_field=cache_field_to_output_field, ) ) dataset = TransformDataset(base_dataset, all_transforms) if batch_sampler is None: return DataLoader( dataset, shuffle=shuffle, collate_fn=sd_image_caption_collate_fn, batch_size=batch_size, num_workers=config.dataloader_num_workers, persistent_workers=config.dataloader_num_workers > 0, ) else: return DataLoader( dataset, batch_sampler=batch_sampler, collate_fn=sd_image_caption_collate_fn, num_workers=config.dataloader_num_workers, persistent_workers=config.dataloader_num_workers > 0, ) ================================================ FILE: src/invoke_training/_shared/data/datasets/__init__.py 
================================================ ================================================ FILE: src/invoke_training/_shared/data/datasets/build_dataset.py ================================================ from datasets import VerificationMode from invoke_training._shared.data.datasets.hf_image_caption_dataset import HFImageCaptionDataset from invoke_training._shared.data.datasets.hf_image_pair_preference_dataset import HFImagePairPreferenceDataset from invoke_training._shared.data.datasets.image_caption_dir_dataset import ImageCaptionDirDataset from invoke_training._shared.data.datasets.image_caption_jsonl_dataset import ImageCaptionJsonlDataset from invoke_training.config.data.dataset_config import ( HFHubImageCaptionDatasetConfig, ImageCaptionDirDatasetConfig, ImageCaptionJsonlDatasetConfig, ) from invoke_training.pipelines._experimental.sd_dpo_lora.config import HFHubImagePairPreferenceDatasetConfig def build_hf_hub_image_caption_dataset(config: HFHubImageCaptionDatasetConfig) -> HFImageCaptionDataset: return HFImageCaptionDataset.from_hub( dataset_name=config.dataset_name, hf_load_dataset_kwargs={ "name": config.dataset_config_name, "cache_dir": config.hf_cache_dir, }, image_column=config.image_column, caption_column=config.caption_column, ) def build_image_caption_jsonl_dataset(config: ImageCaptionJsonlDatasetConfig) -> HFImageCaptionDataset: return ImageCaptionJsonlDataset( jsonl_path=config.jsonl_path, image_column=config.image_column, caption_column=config.caption_column, keep_in_memory=config.keep_in_memory, ) def build_image_caption_dir_dataset(config: ImageCaptionDirDatasetConfig) -> ImageCaptionDirDataset: return ImageCaptionDirDataset( dataset_dir=config.dataset_dir, keep_in_memory=config.keep_in_memory, ) def build_hf_image_pair_preference_dataset( config: HFHubImagePairPreferenceDatasetConfig, ) -> HFImagePairPreferenceDataset: # HACK(ryand): This is currently hard-coded to just download a small slice of the very large # 
'yuvalkirstain/pickapic_v2' dataset. return HFImagePairPreferenceDataset.from_hub( "yuvalkirstain/pickapic_v2", split="train", hf_load_dataset_kwargs={ "data_files": { # "validation_unique": "data/validation_unique-00000-of-00001-33ead111845fc9c4.parquet", "train": [ "data/train-00000-of-00645-b66ac786bf6fb553.parquet", "data/train-00001-of-00645-c7b349dd222d6515.parquet", "data/train-00002-of-00645-e4f54d615a978deb.parquet", "data/train-00003-of-00645-2b9d59bac8b433ff.parquet", "data/train-00004-of-00645-e4964649dc0ea543.parquet", "data/train-00005-of-00645-45e8efc0fe93f6e9.parquet", ] }, # Disable checks so that it doesn't complain that I haven't downloaded the other splits. "verification_mode": VerificationMode.NO_CHECKS, }, ) ================================================ FILE: src/invoke_training/_shared/data/datasets/hf_image_caption_dataset.py ================================================ import os import typing import datasets import torch.utils.data from PIL.Image import Image from invoke_training._shared.data.utils.resolution import Resolution class HFImageCaptionDataset(torch.utils.data.Dataset): """An image-caption dataset wrapper for Hugging Face datasets. The wrapped HF dataset can be either from the HF hub, or in Imagefolder format (https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder). """ def __init__(self, hf_dataset, image_column: str = "image", caption_column: str = "text"): column_names = hf_dataset["train"].column_names if image_column not in column_names: raise ValueError( f"The image_column='{image_column}' is not in the set of dataset column names: '{column_names}'." ) if caption_column not in column_names: raise ValueError( f"The caption_column='{caption_column}' is not in the set of dataset column names: '{column_names}'." 
) self._image_column = image_column def preprocess(examples): images = [image.convert("RGB") for image in examples[image_column]] return { "image": images, "caption": examples[caption_column], } self._hf_dataset = hf_dataset["train"].with_transform(preprocess) @classmethod def from_dir( cls, dataset_dir: str, hf_load_dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None, image_column: str = "image", caption_column: str = "text", ): """Initialize a HFImageCaptionDataset from a Hugging Face ImageFolder dataset directory (https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder). Args: dataset_dir (str): The path to the dataset directory. hf_load_dataset_kwargs (dict[str, typing.Any], optional): kwargs to forward to `datasets.load_dataset(...)`. image_column (str, optional): The name of the image column in the dataset. Defaults to "image". caption_column (str, optional): The name of the caption column in the dataset. Defaults to "text". """ hf_load_dataset_kwargs = hf_load_dataset_kwargs or {} data_files = {"train": os.path.join(dataset_dir, "**")} # See more about loading custom images at # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder hf_dataset = datasets.load_dataset("imagefolder", data_files=data_files, **hf_load_dataset_kwargs) return cls(hf_dataset=hf_dataset, image_column=image_column, caption_column=caption_column) @classmethod def from_hub( cls, dataset_name: str, hf_load_dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None, image_column: str = "image", caption_column: str = "text", ): """Initialize a HFImageCaptionDataset from a Hugging Face Hub dataset. Args: dataset_name (str): The HF Hub dataset name (a.k.a. path). hf_load_dataset_kwargs (dict[str, typing.Any], optional): kwargs to forward to `datasets.load_dataset(...)`. image_column (str, optional): The name of the image column in the dataset. Defaults to "image". caption_column (str, optional): The name of the caption column in the dataset. 
Defaults to "text". """ hf_load_dataset_kwargs = hf_load_dataset_kwargs or {} hf_dataset = datasets.load_dataset(dataset_name, **hf_load_dataset_kwargs) return cls(hf_dataset=hf_dataset, image_column=image_column, caption_column=caption_column) def get_image_dimensions(self) -> list[Resolution]: """Get the dimensions of all images in the dataset. TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to calculate this dynamically every time. """ image_dims: list[Resolution] = [] for i in range(len(self._hf_dataset)): example = self._hf_dataset[i] image: Image = example[self._image_column] image_dims.append(Resolution(image.height, image.width)) return image_dims def __len__(self) -> int: """Get the dataset length. Returns: int: The number of image-caption pairs in the dataset. """ return len(self._hf_dataset) def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]: """Load the dataset example at index `idx`. Raises: IndexError: If `idx` is out of range. Returns: dict: A dataset example with 3 keys: "image", "caption", and "id". The "image" key maps to a `PIL` image in RGB format. The "caption" key maps to a string. The "id" key is the example's index (often used for caching). """ example = self._hf_dataset[idx] example["id"] = idx return example ================================================ FILE: src/invoke_training/_shared/data/datasets/hf_image_pair_preference_dataset.py ================================================ import io import typing import datasets import torch.utils.data from PIL import Image class HFImagePairPreferenceDataset(torch.utils.data.Dataset): """A wrapper for the Hugging Face hub "yuvalkirstain/pickapic_v2" dataset (https://huggingface.co/datasets/yuvalkirstain/pickapic_v2). Designed to be expanded in the future to other HF image pair preference datasets. 
""" def __init__( self, hf_dataset, skip_no_preference=True, split: str = "train", image_0_column: str = "jpg_0", label_0_column: str = "label_0", image_1_column: str = "jpg_1", label_1_column: str = "jpg_1", caption_column: str = "caption", ): """ Args: skip_no_preference (bool, optional): If True, skip image pairs without a preference. """ column_names = hf_dataset[split].column_names for col_name in [image_0_column, label_0_column, image_1_column, label_1_column, caption_column]: if col_name not in column_names: raise ValueError(f"Column '{col_name}' is not in the set of dataset column names: '{column_names}'.") eps = 0.0001 if skip_no_preference: # Filter to only include pairs with a clear preference. def filter(example: dict[str, typing.Any]) -> bool: return abs(example["label_0"] - example["label_1"]) > eps hf_dataset = hf_dataset.filter(filter) def preprocess(examples): image_0_list = [Image.open(io.BytesIO(image)).convert("RGB") for image in examples[image_0_column]] image_1_list = [Image.open(io.BytesIO(image)).convert("RGB") for image in examples[image_1_column]] image_0_is_better = [] image_1_is_better = [] for label_0, label_1 in zip(examples["label_0"], examples["label_1"]): if (label_0 - label_1) > eps: # Label 0 is better. image_0_is_better.append(True) image_1_is_better.append(False) elif (label_1 - label_0) > eps: # Label 1 is better. image_0_is_better.append(False) image_1_is_better.append(True) else: # Tie. image_0_is_better.append(False) image_1_is_better.append(False) return { "image_0": image_0_list, "image_1": image_1_list, "prefer_0": image_0_is_better, "prefer_1": image_1_is_better, "caption": examples[caption_column], } self._hf_dataset = hf_dataset[split].with_transform(preprocess) @classmethod def from_hub( cls, dataset_name: str, skip_no_preference: bool = True, split: str = "train", hf_load_dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None, ): """Initialize a HFImageCaptionDataset from a Hugging Face Hub dataset. 
Args: dataset_name (str): The HF Hub dataset name (a.k.a. path). hf_load_dataset_kwargs (dict[str, typing.Any], optional): kwargs to forward to `datasets.load_dataset(...)`. """ if dataset_name != "yuvalkirstain/pickapic_v2": raise NotImplementedError( "The HFImagePairPreferenceDataset class likely won't work with datasets other than " "'yuvalkirstain/pickapic_v2'." ) hf_load_dataset_kwargs = hf_load_dataset_kwargs or {} hf_dataset = datasets.load_dataset(dataset_name, **hf_load_dataset_kwargs) return cls(hf_dataset=hf_dataset, skip_no_preference=skip_no_preference, split=split) def __len__(self) -> int: """Get the dataset length. Returns: int: The number of image pairs in the dataset. """ return len(self._hf_dataset) def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]: """Load the dataset example at index `idx`. Raises: IndexError: If `idx` is out of range. Returns: dict: A dataset example with the following keys: ["id", "image_1", "caption_1", "image_2", "caption_2", "prefer_1", "prefer_2"] The image keys map to a `PIL` image in RGB format. The caption keys map to strings. The "id" key is the example's index (often used for caching). """ example = self._hf_dataset[idx] example["id"] = idx return example ================================================ FILE: src/invoke_training/_shared/data/datasets/image_caption_dir_dataset.py ================================================ import os import typing import torch.utils.data from PIL import Image from invoke_training._shared.data.utils.resolution import Resolution class ImageCaptionDirDataset(torch.utils.data.Dataset): """A dataset that loads images and captions from a directory of image files and .txt files.""" def __init__( self, dataset_dir: str, id_prefix: str = "", image_extensions: typing.Optional[list[str]] = None, caption_extension: str = ".txt", keep_in_memory: bool = False, ): """Initialize an ImageDirDataset Args: image_dir (str): The directory to load images from. 
id_prefix (str): A prefix added to the 'id' field in every example. image_extensions (list[str], optional): The list of image file extensions to include in the dataset (not case-sensitive). Defaults to [".jpg", ".jpeg", ".png"]. keep_in_memory (bool, optional): If True, keep all images loaded in memory. This improves performance for datasets that are small enough to be kept in memory. """ super().__init__() self._id_prefix = id_prefix if image_extensions is None: image_extensions = [".jpg", ".jpeg", ".png"] image_extensions = [ext.lower() for ext in image_extensions] # Determine the list of image paths to include in the dataset. self._image_paths: list[str] = [] for image_file in os.listdir(dataset_dir): image_path = os.path.join(dataset_dir, image_file) if os.path.isfile(image_path) and os.path.splitext(image_path)[1].lower() in image_extensions: self._image_paths.append(image_path) self._image_paths.sort() # Load captions from .txt files for each image. self._captions: list[str] = [] missing_captions: list[str] = [] for image_path in self._image_paths: caption_path = os.path.splitext(image_path)[0] + caption_extension if os.path.isfile(caption_path): with open(caption_path, "r") as f: self._captions.append(f.read().strip()) else: missing_captions.append(caption_path) if len(missing_captions) > 0: raise Exception(f"The following expected caption files are missing: {missing_captions}") self._images = None if keep_in_memory: self._images = [] for image_path in self._image_paths: self._images.append(self._load_image(image_path)) def _load_image(self, image_path: str) -> Image.Image: # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale # images. return Image.open(image_path).convert("RGB") def get_image_dimensions(self) -> list[Resolution]: """Get the dimensions of all images in the dataset. TODO(ryand): Re-think this approach. For large datasets (e.g. 
streaming from S3) it doesn't make sense to calculate this dynamically every time. """ image_dims: list[Resolution] = [] for i in range(len(self._image_paths)): image_path = self._image_paths[i] image = Image.open(image_path) image_dims.append(Resolution(image.height, image.width)) return image_dims def __len__(self) -> int: return len(self._image_paths) def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]: image = self._images[idx] if self._images is not None else self._load_image(self._image_paths[idx]) return {"id": f"{self._id_prefix}{idx}", "image": image, "caption": self._captions[idx]} ================================================ FILE: src/invoke_training/_shared/data/datasets/image_caption_jsonl_dataset.py ================================================ import typing from pathlib import Path import torch.utils.data from PIL import Image from pydantic import BaseModel from invoke_training._shared.data.utils.resolution import Resolution from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl IMAGE_COLUMN_DEFAULT = "image" CAPTION_COLUMN_DEFAULT = "text" MASK_COLUMN_DEFAULT = "mask" class ImageCaptionExample(BaseModel): image_path: str mask_path: str | None = None caption: str class ImageCaptionJsonlDataset(torch.utils.data.Dataset): """A dataset that loads images and captions from a directory of image files and .txt files.""" def __init__( self, jsonl_path: Path | str, image_column: str = IMAGE_COLUMN_DEFAULT, caption_column: str = CAPTION_COLUMN_DEFAULT, keep_in_memory: bool = False, ): super().__init__() self._jsonl_path = Path(jsonl_path) self._image_column = image_column self._caption_column = caption_column data = load_jsonl(jsonl_path) examples: list[ImageCaptionExample] = [] for d in data: # Clear error messages here are helpful in the Gradio UI. 
if image_column not in d: raise ValueError(f"Column '{image_column}' not found in jsonl file '{jsonl_path}'.") if caption_column not in d: raise ValueError(f"Column '{caption_column}' not found in jsonl file '{jsonl_path}'.") examples.append( ImageCaptionExample( image_path=d[image_column], mask_path=d.get(MASK_COLUMN_DEFAULT, None), caption=d[caption_column] ) ) self.examples = examples self._keep_in_memory = keep_in_memory self._example_cache: dict[int, dict[str, typing.Any]] = {} def save_jsonl(self): data = [] for example in self.examples: data.append( { self._image_column: example.image_path, self._caption_column: example.caption, MASK_COLUMN_DEFAULT: example.mask_path, } ) save_jsonl(data, self._jsonl_path) def _get_image_path(self, idx: int) -> str: image_path = self.examples[idx].image_path image_path = Path(image_path) # image_path could be either absolute, or relative to the jsonl file. if not image_path.is_absolute(): image_path = self._jsonl_path.parent / image_path return image_path def _get_mask_path(self, idx: int) -> str: mask_path = self.examples[idx].mask_path mask_path = Path(mask_path) # mask_path could be either absolute, or relative to the jsonl file. if not mask_path.is_absolute(): mask_path = self._jsonl_path.parent / mask_path return mask_path def _load_image(self, image_path: str) -> Image.Image: # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale # images. return Image.open(image_path).convert("RGB") def _load_mask(self, mask_path: str) -> Image.Image: return Image.open(mask_path).convert("L") def _load_example(self, idx: int) -> dict[str, typing.Any]: example = { "id": str(idx), "image": self._load_image(self._get_image_path(idx)), "caption": self.examples[idx].caption, } if self.examples[idx].mask_path: example["mask"] = self._load_mask(self._get_mask_path(idx)) return example def get_image_dimensions(self) -> list[Resolution]: """Get the dimensions of all images in the dataset. 
TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to calculate this dynamically every time. """ image_dims: list[Resolution] = [] for i in range(len(self.examples)): image = Image.open(self._get_image_path(i)) image_dims.append(Resolution(image.height, image.width)) return image_dims def __len__(self) -> int: return len(self.examples) def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]: if self._keep_in_memory: if idx not in self._example_cache: self._example_cache[idx] = self._load_example(idx) # Return a shallow copy of the example to prevent the caller from modifying the cached example. # Shallow rather than deep, because we don't want to copy the image data. return self._example_cache[idx].copy() return self._load_example(idx) ================================================ FILE: src/invoke_training/_shared/data/datasets/image_dir_dataset.py ================================================ import os import typing import torch.utils.data from PIL import Image from invoke_training._shared.data.utils.resolution import Resolution class ImageDirDataset(torch.utils.data.Dataset): """A dataset that loads image files from a directory.""" def __init__( self, image_dir: str, id_prefix: str = "", image_extensions: typing.Optional[list[str]] = None, keep_in_memory: bool = False, ): """Initialize an ImageDirDataset Args: image_dir (str): The directory to load images from. id_prefix (str): A prefix added to the 'id' field in every example. image_extensions (list[str], optional): The list of image file extensions to include in the dataset (not case-sensitive). Defaults to [".jpg", ".jpeg", ".png"]. keep_in_memory (bool, optional): If True, keep all images loaded in memory. This improves performance for datasets that are small enough to be kept in memory. 
""" super().__init__() self._id_prefix = id_prefix if image_extensions is None: image_extensions = [".jpg", ".jpeg", ".png"] image_extensions = [ext.lower() for ext in image_extensions] self._image_paths = [] for image_file in os.listdir(image_dir): image_path = os.path.join(image_dir, image_file) if os.path.isfile(image_path) and os.path.splitext(image_path)[1].lower() in image_extensions: self._image_paths.append(image_path) self._images = None if keep_in_memory: self._images = [] for image_path in self._image_paths: self._images.append(self._load_image(image_path)) def _load_image(self, image_path: str) -> Image.Image: # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale # images. return Image.open(image_path).convert("RGB") def get_image_dimensions(self) -> list[Resolution]: """Get the dimensions of all images in the dataset. TODO(ryand): Re-think this approach. For large datasets (e.g. streaming from S3) it doesn't make sense to calculate this dynamically every time. 
""" image_dims: list[Resolution] = [] for i in range(len(self._image_paths)): image_path = self._image_paths[i] image = Image.open(image_path) image_dims.append(Resolution(image.height, image.width)) return image_dims def __len__(self) -> int: return len(self._image_paths) def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]: image = self._images[idx] if self._images is not None else self._load_image(self._image_paths[idx]) return {"id": f"{self._id_prefix}{idx}", "image": image} ================================================ FILE: src/invoke_training/_shared/data/datasets/image_pair_preference_dataset.py ================================================ import os import typing from pathlib import Path import torch.utils.data from PIL import Image from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl class ImagePairPreferenceDataset(torch.utils.data.Dataset): def __init__(self, dataset_dir: str): super().__init__() self._dataset_dir = dataset_dir self._metadata = load_jsonl(Path(dataset_dir) / "metadata.jsonl") @classmethod def save_metadata( cls, metadata: list[dict[str, typing.Any]], dataset_dir: str | Path, metadata_file: str = "metadata.jsonl" ) -> Path: """Load the dataset metadata from metadata.jsonl.""" metadata_path = Path(dataset_dir) / metadata_file save_jsonl(metadata, metadata_path) return metadata_path def __len__(self) -> int: return len(self._metadata) def __getitem__(self, idx: int) -> typing.Dict[str, typing.Any]: # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale # images. 
example = self._metadata[idx] image_0_path = os.path.join(self._dataset_dir, example["image_0"]) image_1_path = os.path.join(self._dataset_dir, example["image_1"]) return { "id": str(idx), "image_0": Image.open(image_0_path).convert("RGB"), "image_1": Image.open(image_1_path).convert("RGB"), "caption": example["prompt"], "prefer_0": example["prefer_0"], "prefer_1": example["prefer_1"], } ================================================ FILE: src/invoke_training/_shared/data/datasets/transform_dataset.py ================================================ import typing import torch.utils.data # The data type expected to be produced by the base dataset and handled by transforms. DataType = typing.Dict[str, typing.Any] TransformType = typing.Callable[[DataType], DataType] class TransformDataset(torch.utils.data.Dataset): """A Dataset that wraps a base dataset and applies callable transforms to its outputs.""" def __init__(self, base_dataset: torch.utils.data.Dataset, transforms: list[TransformType]) -> None: super().__init__() self._base_dataset = base_dataset self._transforms = transforms def __len__(self) -> int: return len(self._base_dataset) def __getitem__(self, idx: int) -> DataType: example = self._base_dataset[idx] for t in self._transforms: example = t(example) return example ================================================ FILE: src/invoke_training/_shared/data/samplers/__init__.py ================================================ ================================================ FILE: src/invoke_training/_shared/data/samplers/aspect_ratio_bucket_batch_sampler.py ================================================ import copy import logging import math import random from typing import Iterator from torch.utils.data import Sampler from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager from invoke_training._shared.data.utils.resolution import Resolution AspectRatioBuckets = dict[Resolution, list[int]] class 
AspectRatioBucketBatchSampler(Sampler[list[int]]): """A batch sampler that adheres to aspect ratio buckets.""" def __init__( self, buckets: AspectRatioBuckets, batch_size: int, shuffle: bool = False, seed: int | None = None, ) -> None: """Initialize AspectRatioBucketBatchSampler. For most use cases, initialize via AspectRatioBucketBatchSampler.from_image_sizes(...). """ self._buckets = buckets self._batch_size = batch_size self._shuffle = shuffle self._random = random.Random(seed) def __str__(self) -> str: buckets = self.get_buckets() bucket_resolutions = sorted(list(buckets.keys())) s = "" for bucket_resolution in bucket_resolutions: bucket_images = buckets[bucket_resolution] s += f" {bucket_resolution.to_tuple()}: {len(bucket_images)}\n" return s @classmethod def from_image_sizes( cls, bucket_manager: AspectRatioBucketManager, image_sizes: list[Resolution], batch_size: int, shuffle: bool = False, seed: int | None = None, ): """Initialize from an AspectRatioBucketManager and the list of dataset image resolutions.""" buckets = cls._build_bucket_to_index_map(bucket_manager, image_sizes) return cls(buckets=buckets, batch_size=batch_size, shuffle=shuffle, seed=seed) @classmethod def _build_bucket_to_index_map( cls, bucket_manager: AspectRatioBucketManager, image_sizes: list[Resolution], ) -> AspectRatioBuckets: bucket_to_indexes: AspectRatioBuckets = dict() for bucket_resolution in bucket_manager.buckets: bucket_to_indexes[bucket_resolution] = [] for index, image_size in enumerate(image_sizes): aspect_ratio_bucket = bucket_manager.get_aspect_ratio_bucket(image_size) bucket_to_indexes[aspect_ratio_bucket].append(index) return bucket_to_indexes def get_buckets(self) -> AspectRatioBuckets: return copy.deepcopy(self._buckets) def __iter__(self) -> Iterator[list[int]]: batches: list[list[int]] = [] # TODO(ryand): If self._shuffle == False, should we still shuffle just with a fixed seed every time? 
# If we don't shuffle at all then all of the batches from a bucket will be grouped together. If there's a correlation
        # between aspect ratio and image content in a dataset, this could result in unevenly distributed image content
        # over the dataset.
        for bucket_resolution in sorted(list(self._buckets.keys())):
            ordered_bucket_images = self._buckets[bucket_resolution].copy()
            if self._shuffle:
                # Shuffle the images within a bucket.
                self._random.shuffle(ordered_bucket_images)

            # Prepare batches for a single bucket. The last batch of a bucket may be smaller than the batch size.
            batch_start = 0
            while batch_start < len(ordered_bucket_images):
                batch_end = min(batch_start + self._batch_size, len(ordered_bucket_images))
                batches.append(ordered_bucket_images[batch_start:batch_end])
                batch_start += self._batch_size

        if self._shuffle:
            # We've already shuffled the images within each bucket, now we shuffle the batches.
            self._random.shuffle(batches)

        yield from batches

    def __len__(self) -> int:
        """Return the number of batches per iteration: sum over buckets of ceil(len(bucket) / batch_size)."""
        num_batches = 0
        for bucket_images in self._buckets.values():
            num_batches += math.ceil(len(bucket_images) / self._batch_size)
        return num_batches


def log_aspect_ratio_buckets(logger: logging.Logger, batch_sampler: AspectRatioBucketBatchSampler):
    """Log the bucket resolutions and image counts of `batch_sampler` at INFO level.

    Does nothing if `batch_sampler` is not an AspectRatioBucketBatchSampler.
    """
    if not isinstance(batch_sampler, AspectRatioBucketBatchSampler):
        return

    log = "Aspect Ratio Buckets:\n"
    log += str(batch_sampler)
    logger.info(log)


================================================
FILE: src/invoke_training/_shared/data/samplers/batch_offset_sampler.py
================================================
import typing

from torch.utils.data import Sampler


class BatchOffsetSampler(Sampler[int]):
    """A sampler that wraps a batch sampler and applies an offset to all returned batch elements."""

    def __init__(self, sampler: Sampler[list[int]], offset: int):
        # `sampler` yields batches (iterables of indices), not individual indices.
        self._sampler = sampler
        self._offset = offset

    def __iter__(self) -> typing.Iterator[list[int]]:
        for batch in self._sampler:
            offset_batch = [x + self._offset for x in batch]
            yield offset_batch

    def __len__(self) -> int:
        # The number of batches is unchanged by the offset.
        return len(self._sampler)


================================================
FILE: src/invoke_training/_shared/data/samplers/concat_sampler.py
================================================
import itertools
import typing

from torch.utils.data import Sampler

T_co = typing.TypeVar("T_co", covariant=True)


class ConcatSampler(Sampler[T_co]):
    """A meta-Sampler that concatenates multiple samplers.

    Example:
        sampler 1: ABCD
        sampler 2: EFG
        sampler 3: HIJKLM

        ConcatSampler: ABCDEFGHIJKLM
    """

    def __init__(self, samplers: list[Sampler[T_co] | typing.Iterable[T_co]]) -> None:
        self._samplers = samplers

    def __iter__(self) -> typing.Iterator[T_co]:
        return itertools.chain(*self._samplers)

    def __len__(self) -> int:
        # Total length is the sum of the input samplers' lengths.
        return sum([len(s) for s in self._samplers])


================================================
FILE: src/invoke_training/_shared/data/samplers/interleaved_sampler.py
================================================
import typing

from torch.utils.data import Sampler

T_co = typing.TypeVar("T_co", covariant=True)


class InterleavedSampler(Sampler[T_co]):
    """A meta-Sampler that interleaves multiple samplers.

    The length of this sampler is based on the length of the shortest input sampler. All samplers will contribute the
    same number of samples to the interleaved output.

    Example:
        sampler 1: ABCD
        sampler 2: EFG
        sampler 3: HIJKLM

        interleaved sampler: AEHBFICGJ
    """

    def __init__(self, samplers: list[Sampler[T_co] | typing.Iterable[T_co]]) -> None:
        self._samplers = samplers
        self._min_sampler_len = min([len(s) for s in self._samplers])

    def __iter__(self) -> typing.Iterator[T_co]:
        sampler_iters = [iter(s) for s in self._samplers]
        while True:
            # Each round takes one sample from every sampler. A round is only yielded once all samplers have produced
            # a sample, so an incomplete final round is dropped.
            samples = []
            for sampler_iter in sampler_iters:
                try:
                    samples.append(next(sampler_iter))
                except StopIteration:
                    # The end of the shortest sampler has been reached.
                    return

            yield from samples

    def __len__(self) -> int:
        return self._min_sampler_len * len(self._samplers)


================================================
FILE: src/invoke_training/_shared/data/samplers/offset_sampler.py
================================================
import typing

from torch.utils.data import Sampler


class OffsetSampler(Sampler[int]):
    """A sampler that wraps another sampler and applies an offset to all returned values."""

    def __init__(self, sampler: Sampler[int], offset: int):
        self._sampler = sampler
        self._offset = offset

    def __iter__(self) -> typing.Iterator[int]:
        for idx in self._sampler:
            yield idx + self._offset

    def __len__(self) -> int:
        return len(self._sampler)


================================================
FILE: src/invoke_training/_shared/data/transforms/__init__.py
================================================


================================================
FILE: src/invoke_training/_shared/data/transforms/caption_prefix_transform.py
================================================
import typing


class CaptionPrefixTransform:
    """A transform that adds a prefix to all example captions."""

    def __init__(self, caption_field_name: str, prefix: str):
        self._caption_field_name = caption_field_name
        self._prefix = prefix

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        """Prepend the prefix to the caption field, modifying `data` in place, and return it."""
        data[self._caption_field_name] = self._prefix + data[self._caption_field_name]
        return data


================================================
FILE: src/invoke_training/_shared/data/transforms/concat_fields_transform.py
================================================
import typing


class ConcatFieldsTransform:
    """A transform that concatenates multiple string fields into a single field."""

    def __init__(self, src_field_names: list[str], dst_field_name: str, separator: str = " "):
        self._src_field_names = src_field_names
        self._dst_field_name = dst_field_name
        self._separator = separator

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        """Join the source fields (in order) with the separator and store the result in the destination field."""
        result = self._separator.join([data[field_name] for field_name in self._src_field_names])
        data[self._dst_field_name] = result
        return data


================================================
FILE: src/invoke_training/_shared/data/transforms/constant_field_transform.py
================================================
import typing


class ConstantFieldTransform:
    """A simple transform that adds a constant field to every example."""

    def __init__(self, field_name: str, field_value: typing.Any):
        self._field_name = field_name
        self._field_value = field_value

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        """Set the field to the constant value (overwriting any existing value) and return `data`."""
        data[self._field_name] = self._field_value
        return data


================================================
FILE: src/invoke_training/_shared/data/transforms/drop_field_transform.py
================================================
import typing


class DropFieldTransform:
    """A simple transform that drops a field from an example."""

    def __init__(self, field_to_drop: str):
        self._field_to_drop = field_to_drop

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        """Remove the field from `data` if present (a missing field is not an error) and return `data`."""
        if self._field_to_drop in data:
            del data[self._field_to_drop]
        return data


================================================
FILE: src/invoke_training/_shared/data/transforms/flux_image_transform.py
================================================
import typing

from torchvision import transforms
from torchvision.transforms.functional import crop

from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager, Resolution
from invoke_training._shared.data.utils.resize import resize_to_cover


class FluxImageTransform:
    """A transform that prepares and augments images for Flux.1-dev training."""

    def __init__(
        self,
        image_field_names: list[str],
        fields_to_normalize_to_range_minus_one_to_one: list[str],
        resolution: int | None = 512,
        aspect_ratio_bucket_manager: AspectRatioBucketManager | None = None,
        random_flip: bool = True,
center_crop: bool = True, ): """Initialize FluxImageTransform. Args: image_field_names (list[str]): The field names of the images to be transformed. resolution (int): The image resolution that will be produced. One of `resolution` and `aspect_ratio_bucket_manager` should be non-None. aspect_ratio_bucket_manager (AspectRatioBucketManager): The AspectRatioBucketManager used to determine the target resolution for each image. One of `resolution` and `aspect_ratio_bucket_manager` should be non-None. center_crop (bool, optional): If True, crop to the center of the image to achieve the target resolution. If False, crop at a random location. random_flip (bool, optional): Whether to apply a random horizontal flip to the images. """ self.image_field_names = image_field_names self.fields_to_normalize_to_range_minus_one_to_one = fields_to_normalize_to_range_minus_one_to_one self.resolution = resolution self.aspect_ratio_bucket_manager = aspect_ratio_bucket_manager self.random_flip = random_flip self.center_crop = center_crop def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: # noqa: C901 image_fields: dict = {} for field_name in self.image_field_names: image_fields[field_name] = data[field_name] # Get the first image to determine original size and resolution first_image = next(iter(image_fields.values())) original_size_hw = (first_image.height, first_image.width) for field_name, image in image_fields.items(): # Determine the target image resolution. 
if self.resolution is not None: resolution = self.resolution resolution_obj = Resolution(resolution, resolution) else: resolution_obj = self.aspect_ratio_bucket_manager.get_aspect_ratio_bucket( Resolution.parse(original_size_hw) ) image = resize_to_cover(image, resolution_obj) # Apply cropping and record top left crop position if self.center_crop: top_left_y = max(0, (image.height - resolution_obj.height) // 2) top_left_x = max(0, (image.width - resolution_obj.width) // 2) image = transforms.CenterCrop(resolution_obj.to_tuple())(image) else: crop_transform = transforms.RandomCrop(resolution_obj.to_tuple()) top_left_y, top_left_x, h, w = crop_transform.get_params(image, resolution_obj.to_tuple()) image = crop(image, top_left_y, top_left_x, resolution_obj.height, resolution_obj.width) # Apply random flip and update top left crop position accordingly if self.random_flip: # TODO: Use a seed for repeatable results import random if random.random() < 0.5: top_left_x = original_size_hw[1] - image.width - top_left_x image = transforms.RandomHorizontalFlip(p=1.0)(image) image = transforms.ToTensor()(image) if field_name in self.fields_to_normalize_to_range_minus_one_to_one: image_fields[field_name] = transforms.Normalize([0.5], [0.5])(image) else: image_fields[field_name] = image # Store the processed images and metadata for field_name, image in image_fields.items(): data[field_name] = image # Add metadata fields expected by VAE caching data["original_size_hw"] = original_size_hw data["crop_top_left_yx"] = (top_left_y, top_left_x) return data ================================================ FILE: src/invoke_training/_shared/data/transforms/load_cache_transform.py ================================================ import typing from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache class LoadCacheTransform: """A transform that loads data from a TensorDiskCache.""" def __init__( self, cache: TensorDiskCache, cache_key_field: str, 
cache_field_to_output_field: typing.Dict[str, str]
    ):
        """Initialize LoadCacheTransform.

        Args:
            cache (TensorDiskCache): The cache to load from.
            cache_key_field (str): The name of the field to use as the cache key.
            cache_field_to_output_field (typing.Dict[str, str]): A map of field names in the cached data to the field
                names where they should be inserted in the example data.
        """
        self._cache = cache
        self._cache_key_field = cache_key_field
        self._cache_field_to_output_field = cache_field_to_output_field

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
        """Load the cache entry keyed by `data[cache_key_field]` and copy the mapped fields into `data`."""
        key = data[self._cache_key_field]
        cache_data = self._cache.load(key)
        for src, dst in self._cache_field_to_output_field.items():
            data[dst] = cache_data[src]
        return data


================================================
FILE: src/invoke_training/_shared/data/transforms/sd_image_transform.py
================================================
import random
import typing

from torchvision import transforms
from torchvision.transforms.functional import crop

from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager, Resolution
from invoke_training._shared.data.utils.resize import resize_to_cover


class SDImageTransform:
    """A transform that prepares and augments images for Stable Diffusion training."""

    def __init__(
        self,
        image_field_names: list[str],
        fields_to_normalize_to_range_minus_one_to_one: list[str],
        resolution: int | tuple[int, int] | Resolution | None,
        aspect_ratio_bucket_manager: AspectRatioBucketManager | None = None,
        center_crop: bool = True,
        random_flip: bool = False,
        orig_size_field_name: str = "original_size_hw",
        crop_field_name: str = "crop_top_left_yx",
    ):
        """Initialize SDImageTransform.

        Args:
            image_field_names (list[str]): The field names of the images to be transformed.
            fields_to_normalize_to_range_minus_one_to_one (list[str]): The subset of `image_field_names` whose pixel
                values are normalized from [0, 1] to [-1, 1]. The other image fields are left in [0, 1].
            resolution (int | tuple[int, int] | Resolution): The image resolution that will be produced (parsed with
                `Resolution.parse`). Exactly one of `resolution` and `aspect_ratio_bucket_manager` must be non-None.
            aspect_ratio_bucket_manager (AspectRatioBucketManager): The AspectRatioBucketManager used to determine the
                target resolution for each image. Exactly one of `resolution` and `aspect_ratio_bucket_manager` must be
                non-None.
            center_crop (bool, optional): If True, crop to the center of the image to achieve the target resolution.
                If False, crop at a random location.
            random_flip (bool, optional): Whether to apply a random horizontal flip to the images.
            orig_size_field_name (str, optional): The field that the original image size (height, width) is written
                to.
            crop_field_name (str, optional): The field that the top-left crop coordinate (y, x) is written to.

        Raises:
            ValueError: If both or neither of `resolution` and `aspect_ratio_bucket_manager` are set.
        """
        self._image_field_names = image_field_names
        self._fields_to_normalize_to_range_minus_one_to_one = fields_to_normalize_to_range_minus_one_to_one

        if resolution is not None and aspect_ratio_bucket_manager is not None:
            raise ValueError("Only one of `resolution` or `aspect_ratio_bucket_manager` should be set.")
        if resolution is None and aspect_ratio_bucket_manager is None:
            raise ValueError("One of `resolution` or `aspect_ratio_bucket_manager` must be set.")
        self._resolution = Resolution.parse(resolution) if resolution is not None else None
        self._aspect_ratio_bucket_manager = aspect_ratio_bucket_manager

        self._center_crop_enabled = center_crop
        self._random_flip_enabled = random_flip

        self._flip_transform = transforms.RandomHorizontalFlip(p=1.0)
        self._to_tensor_transform = transforms.ToTensor()
        # Convert pixel values from range [0, 1.0] to range [-1.0, 1.0].
        # Normalize applies the following transform: out = (in - 0.5) / 0.5
        self._normalize_image_transform = transforms.Normalize([0.5], [0.5])

        self._orig_size_field_name = orig_size_field_name
        self._crop_field_name = crop_field_name

    def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:  # noqa: C901
        """Resize, crop, optionally flip, and convert the configured image fields of `data` to tensors.

        All image fields must have the same size. They all receive the same crop window and flip decision. The
        original size and the (flip-adjusted) top-left crop coordinate are written to the configured metadata fields.
        """
        # This SDXL image pre-processing logic is adapted from:
        # https://github.com/huggingface/diffusers/blob/7b07f9812a58bfa96c06ed8ffe9e6b584286e2fd/examples/text_to_image/train_text_to_image_lora_sdxl.py#L850-L873

        image_fields: dict = {}
        for field_name in self._image_field_names:
            image_fields[field_name] = data[field_name]

        sizes = [image.size for image in image_fields.values()]
        # All images should have the same size.
        assert all(size == sizes[0] for size in sizes)

        # Helper function to access the first image, which is sometimes used to infer the shape of all images.
        def get_first_image():
            return next(iter(image_fields.values()))

        original_size_hw = (get_first_image().height, get_first_image().width)

        # Determine the target image resolution.
        if self._resolution is not None:
            resolution = self._resolution
        else:
            resolution = self._aspect_ratio_bucket_manager.get_aspect_ratio_bucket(Resolution.parse(original_size_hw))

        # Resize to cover the target resolution while preserving aspect ratio.
        for field_name, image in image_fields.items():
            image_fields[field_name] = resize_to_cover(image, resolution)

        # Apply cropping, and record top left crop position.
        if self._center_crop_enabled:
            top_left_y = max(0, (get_first_image().height - resolution.height) // 2)
            top_left_x = max(0, (get_first_image().width - resolution.width) // 2)
        else:
            crop_transform = transforms.RandomCrop(resolution.to_tuple())
            top_left_y, top_left_x, h, w = crop_transform.get_params(get_first_image(), resolution.to_tuple())
        # The same crop window (computed from the first image) is applied to every field.
        for field_name, image in image_fields.items():
            image_fields[field_name] = crop(image, top_left_y, top_left_x, resolution.height, resolution.width)

        # Apply random flip and update top left crop position accordingly.
        # TODO(ryand): Use a seed for repeatable results.
        if self._random_flip_enabled and random.random() < 0.5:
            top_left_x = original_size_hw[1] - get_first_image().width - top_left_x
            for field_name, image in image_fields.items():
                image_fields[field_name] = self._flip_transform(image)

        crop_top_left_yx = (top_left_y, top_left_x)

        # Convert to Tensors.
        for field_name, image in image_fields.items():
            image_fields[field_name] = self._to_tensor_transform(image)

        # Normalize to range [-1.0, 1.0].
        # HACK(ryand): We should find a better way to determine the normalization range of each image field.
        for field_name, image in image_fields.items():
            if field_name in self._fields_to_normalize_to_range_minus_one_to_one:
                image_fields[field_name] = self._normalize_image_transform(image)

        data[self._orig_size_field_name] = original_size_hw
        data[self._crop_field_name] = crop_top_left_yx
        for field_name, image in image_fields.items():
            data[field_name] = image

        return data


================================================
FILE: src/invoke_training/_shared/data/transforms/shuffle_caption_transform.py
================================================
import typing

import numpy as np


class ShuffleCaptionTransform:
    """A transform that applies shuffle transformations to character-delimited captions.
Example: - Original: "unreal engine, render of sci-fi helmet, dramatic lighting" - Shuffled: "render of sci-fi helmet, unreal engine, dramatic lighting" """ def __init__(self, field_name: str, delimiter: str = ",", seed: int = 0): self._field_name = field_name self._delimiter = delimiter self._rng = np.random.default_rng(seed) def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: caption: str = data[self._field_name] caption_chunks = caption.split(self._delimiter) caption_chunks = [s.strip() for s in caption_chunks] self._rng.shuffle(caption_chunks) join_str = self._delimiter + " " data[self._field_name] = join_str.join(caption_chunks) return data ================================================ FILE: src/invoke_training/_shared/data/transforms/template_caption_transform.py ================================================ import typing import numpy as np class TemplateCaptionTransform: """A simple transform that constructs a caption for each example by combining a caption template with the placeholder string. """ def __init__(self, field_name: str, placeholder_str: str, caption_templates: list[str], seed: int = 0): self._field_name = field_name self._placeholder_str = placeholder_str self._caption_templates = caption_templates self._rng = np.random.default_rng(seed) def __call__(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: caption = self._rng.choice(self._caption_templates).format(self._placeholder_str) # Assert that the template was well-formed such that the placeholder string is in the output caption. 
assert self._placeholder_str in caption data[self._field_name] = caption return data ================================================ FILE: src/invoke_training/_shared/data/transforms/tensor_disk_cache.py ================================================ import os import typing import torch class TensorDiskCache: """A data cache that caches `torch.Tensor`s on disk.""" def __init__(self, cache_dir: str): super().__init__() self._cache_dir = cache_dir os.makedirs(self._cache_dir, exist_ok=True) def _get_path(self, key: int): """Get the cache file path for `key`. Args: key (int): The cache key. Returns: str: The cache file path. """ return os.path.join(self._cache_dir, f"{key}.pt") def save(self, key: int, data: typing.Dict[str, torch.Tensor]): """Save data in the cache. Raises: AssertionError: If an entry already exists in the cache for this `key`. Args: key (int): The cache key. data (typing.Dict[str, torch.Tensor]): The data to save. """ # torch.save() supports a range of different data types, but it is cleaner if we force everyone to use a dict. # This allows for more reusable cache loading code. assert isinstance(data, dict) save_path = self._get_path(key) assert not os.path.exists(save_path) torch.save(data, save_path) def load(self, key: int) -> typing.Dict[str, torch.Tensor]: """Load data from the cache. Args: key (int): The cache key to load. Returns: typing.Dict[str, torch.Tensor]: Data loaded from the cache. 
""" return torch.load(self._get_path(key)) ================================================ FILE: src/invoke_training/_shared/data/utils/__init__.py ================================================ ================================================ FILE: src/invoke_training/_shared/data/utils/aspect_ratio_bucket_manager.py ================================================ from invoke_training._shared.data.utils.resolution import Resolution class AspectRatioBucketManager: def __init__(self, buckets: set[Resolution]): self.buckets = buckets @classmethod def from_constraints(cls, target_resolution: int, start_dim: int, end_dim: int, divisible_by: int) -> None: buckets = cls.build_aspect_ratio_buckets( target_resolution=target_resolution, start_dim=start_dim, end_dim=end_dim, divisible_by=divisible_by, ) return cls(buckets) @classmethod def build_aspect_ratio_buckets( cls, target_resolution: int, start_dim: int, end_dim: int, divisible_by: int ) -> set[Resolution]: """Prepare a set of aspect ratios. Args: target_resolution (Resolution): All resolutions in the returned set will aim to have close to (but <=) `target_resolution * target_resolution` pixels. start_dim (int): end_dim (int): divisible_by (int): All dimensions in the returned set of resolutions will be divisible by `divisible_by`. Returns: set[tuple[int, int]]: The aspect ratio bucket resolutions. """ # Validate target_resolution. assert target_resolution % divisible_by == 0 # Validate start_dim, end_dim. 
assert start_dim <= end_dim assert start_dim % divisible_by == 0 assert end_dim % divisible_by == 0 target_size = target_resolution * target_resolution buckets = set() height = start_dim while height <= end_dim: width = (target_size // height) // divisible_by * divisible_by buckets.add(Resolution(height, width)) buckets.add(Resolution(width, height)) height += divisible_by return buckets def get_aspect_ratio_bucket(self, resolution: Resolution): """Get the bucket with the closest aspect ratio to 'resolution'.""" # Note: If this is ever found to be a bottleneck, there is a clearly-more-efficient implementation using bisect. return min(self.buckets, key=lambda x: abs(x.aspect_ratio() - resolution.aspect_ratio())) ================================================ FILE: src/invoke_training/_shared/data/utils/resize.py ================================================ import math from PIL.Image import Image from torchvision import transforms from invoke_training._shared.data.utils.resolution import Resolution def resize_to_cover(image: Image, size_to_cover: Resolution) -> Image: """Resize image to the smallest size that covers 'size_to_cover' while preserving its aspect ratio. In other words, achieve the following: - resized_height >= size_to_cover.height - resized_width >= size_to_cover.width - resized_height == size_to_cover.height or resized_width == size_to_cover.width - 'image' aspect ratio is preserved. 
""" scale_to_height = size_to_cover.height / image.height scale_to_width = size_to_cover.width / image.width if scale_to_height > scale_to_width: resize_height = size_to_cover.height resize_width = math.ceil(image.width * scale_to_height) else: resize_width = size_to_cover.width resize_height = math.ceil(image.height * scale_to_width) resize_transform = transforms.Resize( (resize_height, resize_width), interpolation=transforms.InterpolationMode.BILINEAR ) return resize_transform(image) ================================================ FILE: src/invoke_training/_shared/data/utils/resolution.py ================================================ from typing import Union class Resolution: def __init__(self, height: int, width: int): self.height = height self.width = width @classmethod def parse(cls, resolution: Union[int, tuple[int, int], "Resolution"]): """Initialize a Resolution object from another type.""" if isinstance(resolution, int): # Assume square resolution. return cls(resolution, resolution) elif isinstance(resolution, tuple): height, width = resolution return cls(height, width) elif isinstance(resolution, cls): return cls(resolution.height, resolution.width) else: raise ValueError(f"Unsupported resolution type: '{type(resolution)}'.") def aspect_ratio(self): return self.height / self.width def to_tuple(self) -> tuple[int, int]: return (self.height, self.width) def __eq__(self, other: "Resolution") -> bool: return self.to_tuple() == other.to_tuple() def __lt__(self, other: "Resolution") -> bool: return self.to_tuple() < other.to_tuple() def __hash__(self): return hash(self.to_tuple()) ================================================ FILE: src/invoke_training/_shared/flux/encoding_utils.py ================================================ import logging from typing import List, Optional, Tuple, Union import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast def get_clip_prompt_embeds( prompt: Union[str, List[str]], 
tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, device: torch.device, num_images_per_prompt: int = 1, tokenizer_max_length: int = 77, logger: logging.Logger | None = None, ) -> torch.FloatTensor: """Encodes the prompt using CLIP text encoder and returns pooled embeddings.""" prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) # Process text input with the tokenizer text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer_max_length, truncation=True, return_overflowing_tokens=False, return_length=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids # Check if truncation occurred if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1]) if logger is not None: logger.warning(f"Warning: The following part of your input was truncated: {removed_text}") # Get prompt embeddings through the text encoder prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) # Duplicate text embeddings for each generation per prompt prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds def get_t5_prompt_embeds( prompt: Union[str, List[str]], tokenizer: T5TokenizerFast, text_encoder: T5EncoderModel, device: torch.device, num_images_per_prompt: int = 1, tokenizer_max_length: int = 512, logger: logging.Logger | None = None, ) -> torch.FloatTensor: """Encodes the prompt using T5 text encoder.""" prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) # Process text input with the tokenizer 
text_inputs = tokenizer( prompt, padding="max_length", max_length=tokenizer_max_length, truncation=True, return_length=False, return_overflowing_tokens=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids # Check if truncation occurred if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1]) if logger is not None: logger.warning(f"Warning: The following part of your input was truncated: {removed_text}") # Get prompt embeddings through the text encoder prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)[0] dtype = text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) # Get shape and duplicate for multiple generations _, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds def handle_lora_scale( clip_text_encoder: CLIPTextModel, t5_text_encoder: T5EncoderModel, lora_scale: Optional[float] = None, use_peft_backend: bool = False, ): """Handles LoRA scale adjustments for text encoders.""" if lora_scale is not None and use_peft_backend: from peft.utils import scale_lora_layers # Apply LoRA scaling to text encoders if they exist if clip_text_encoder is not None: scale_lora_layers(clip_text_encoder, lora_scale) if t5_text_encoder is not None: scale_lora_layers(t5_text_encoder, lora_scale) return True return False def reset_lora_scale( clip_text_encoder: CLIPTextModel, t5_text_encoder: T5EncoderModel, lora_scale: Optional[float] = None, lora_applied: bool = False, use_peft_backend: bool = False, ): """Resets LoRA scale for text encoders if it was applied.""" if lora_applied and use_peft_backend: from peft.utils import 
unscale_lora_layers # Reset LoRA scaling if clip_text_encoder is not None: unscale_lora_layers(clip_text_encoder, lora_scale) if t5_text_encoder is not None: unscale_lora_layers(t5_text_encoder, lora_scale) # A lot of this code was adapted from: # https://github.com/huggingface/diffusers/blob/ea81a4228d8ff16042c3ccaf61f0e588e60166cd/src/diffusers/pipelines/flux/pipeline_flux.py#L310-L387 def encode_prompt( prompt: Union[str, List[str]], prompt_2: Optional[Union[str, List[str]]], clip_tokenizer: CLIPTokenizer, t5_tokenizer: T5TokenizerFast, clip_text_encoder: CLIPTextModel, t5_text_encoder: T5EncoderModel, device: torch.device, num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, lora_scale: Optional[float] = None, use_peft_backend: bool = False, clip_tokenizer_max_length: int = 77, t5_tokenizer_max_length: int = 512, logger: logging.Logger | None = None, ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: """ Encodes the prompt using both CLIP and T5 text encoders. 
Returns: Tuple containing: - T5 text embeddings - CLIP pooled embeddings - Text IDs """ # Apply LoRA scale if needed lora_applied = handle_lora_scale( clip_text_encoder=clip_text_encoder, t5_text_encoder=t5_text_encoder, lora_scale=lora_scale, use_peft_backend=use_peft_backend, ) # If no pre-generated embeddings, create them if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # Get CLIP pooled embeddings pooled_prompt_embeds = get_clip_prompt_embeds( prompt=prompt, tokenizer=clip_tokenizer, text_encoder=clip_text_encoder, device=device, num_images_per_prompt=num_images_per_prompt, tokenizer_max_length=clip_tokenizer_max_length, ) # Get T5 text embeddings prompt_embeds = get_t5_prompt_embeds( prompt=prompt_2, tokenizer=t5_tokenizer, text_encoder=t5_text_encoder, device=device, num_images_per_prompt=num_images_per_prompt, tokenizer_max_length=t5_tokenizer_max_length, ) # Reset LoRA scale if it was applied reset_lora_scale( clip_text_encoder=clip_text_encoder, t5_text_encoder=t5_text_encoder, lora_scale=lora_scale, lora_applied=lora_applied, use_peft_backend=use_peft_backend, ) # Create text_ids placeholder for model dtype = clip_text_encoder.dtype if clip_text_encoder is not None else t5_text_encoder.dtype text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids ================================================ FILE: src/invoke_training/_shared/flux/lora_checkpoint_utils.py ================================================ # ruff: noqa: N806 import os from pathlib import Path import peft import torch from diffusers import FluxTransformer2DModel from transformers import CLIPTextModel from invoke_training._shared.checkpoints.lora_checkpoint_utils import ( _convert_peft_state_dict_to_kohya_state_dict, load_multi_model_peft_checkpoint, save_multi_model_peft_checkpoint, ) from invoke_training._shared.checkpoints.serialization 
import save_state_dict FLUX_TRANSFORMER_TARGET_MODULES = [ # double blocks "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out", "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0", "ff.net.0.proj", "ff.net.2.0", "ff_context.net.0.proj", "ff_context.net.2.0", # single blocks "attn.to_k", "attn.to_q", "attn.to_v", "proj_mlp", "proj_out", "proj_in", ] TEXT_ENCODER_TARGET_MODULES = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"] # Module lists copied from diffusers training script. # These module lists will produce lighter, less expressive, LoRA models than the non-light versions. FLUX_TRANSFORMER_TARGET_MODULES_LIGHT = ["to_k", "to_q", "to_v", "to_out.0"] FLUX_TEXT_ENCODER_TARGET_MODULES_LIGHT = ["q_proj", "k_proj", "v_proj", "out_proj"] FLUX_PEFT_TRANSFORMER_KEY = "transformer" FLUX_PEFT_TEXT_ENCODER_1_KEY = "text_encoder_1" FLUX_PEFT_TEXT_ENCODER_2_KEY = "text_encoder_2" FLUX_KOHYA_TRANSFORMER_KEY = "lora_unet" FLUX_KOHYA_TEXT_ENCODER_1_KEY = "lora_clip" FLUX_KOHYA_TEXT_ENCODER_2_KEY = "lora_t5" FLUX_PEFT_TO_KOHYA_KEYS = { FLUX_PEFT_TRANSFORMER_KEY: FLUX_KOHYA_TRANSFORMER_KEY, FLUX_PEFT_TEXT_ENCODER_1_KEY: FLUX_KOHYA_TEXT_ENCODER_1_KEY, FLUX_PEFT_TEXT_ENCODER_2_KEY: FLUX_KOHYA_TEXT_ENCODER_2_KEY, } def save_flux_peft_checkpoint( checkpoint_dir: Path | str, transformer: peft.PeftModel | None, text_encoder_1: peft.PeftModel | None, text_encoder_2: peft.PeftModel | None, ): models = {} if transformer is not None: models[FLUX_PEFT_TRANSFORMER_KEY] = transformer if text_encoder_1 is not None: models[FLUX_PEFT_TEXT_ENCODER_1_KEY] = text_encoder_1 if text_encoder_2 is not None: models[FLUX_PEFT_TEXT_ENCODER_2_KEY] = text_encoder_2 save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=models) def load_flux_peft_checkpoint( checkpoint_dir: Path | str, transformer: FluxTransformer2DModel, text_encoder_1: CLIPTextModel, text_encoder_2: CLIPTextModel, is_trainable: bool = False, ): models = load_multi_model_peft_checkpoint( 
checkpoint_dir=checkpoint_dir, models={ FLUX_PEFT_TRANSFORMER_KEY: transformer, FLUX_PEFT_TEXT_ENCODER_1_KEY: text_encoder_1, FLUX_PEFT_TEXT_ENCODER_2_KEY: text_encoder_2, }, is_trainable=is_trainable, raise_if_subdir_missing=False, ) return models[FLUX_PEFT_TRANSFORMER_KEY], models[FLUX_PEFT_TEXT_ENCODER_1_KEY], models[FLUX_PEFT_TEXT_ENCODER_2_KEY] def save_flux_kohya_checkpoint( checkpoint_path: Path, transformer: peft.PeftModel | None, text_encoder_1: peft.PeftModel | None, text_encoder_2: peft.PeftModel | None, ): kohya_prefixes = [] models = [] for kohya_prefix, peft_model in zip( [FLUX_KOHYA_TRANSFORMER_KEY, FLUX_KOHYA_TEXT_ENCODER_1_KEY], [transformer, text_encoder_1] ): if peft_model is not None: kohya_prefixes.append(kohya_prefix) models.append(peft_model) kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models) checkpoint_path.parent.mkdir(parents=True, exist_ok=True) save_state_dict(kohya_state_dict, checkpoint_path) def convert_flux_peft_checkpoint_to_kohya_state_dict( in_checkpoint_dir: Path, out_checkpoint_file: Path, dtype: torch.dtype = torch.float32, ) -> dict[str, torch.Tensor]: """Convert Flux PEFT models to a Kohya-format LoRA state dict.""" # Get the immediate subdirectories of the checkpoint directory. We assume that each subdirectory is a PEFT model. peft_model_dirs = os.listdir(in_checkpoint_dir) peft_model_dirs = [in_checkpoint_dir / d for d in peft_model_dirs] # Convert to Path objects. peft_model_dirs = [d for d in peft_model_dirs if d.is_dir()] # Filter out non-directories. 
if len(peft_model_dirs) == 0: raise ValueError(f"No checkpoint files found in directory '{in_checkpoint_dir}'.") kohya_state_dict = {} for peft_model_dir in peft_model_dirs: if peft_model_dir.name in FLUX_PEFT_TO_KOHYA_KEYS: kohya_prefix = FLUX_PEFT_TO_KOHYA_KEYS[peft_model_dir.name] else: raise ValueError(f"Unrecognized checkpoint directory: '{peft_model_dir}'.") # Note: This logic to load the LoraConfig and weights directly is based on how it is done here: # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/src/peft/peft_model.py#L672-L689 # This may need to be updated in the future to support other adapter types (LoKr, LoHa, etc.). # Also, I could see this interface breaking in the future. lora_config = peft.LoraConfig.from_pretrained(peft_model_dir) lora_weights = peft.utils.load_peft_weights(peft_model_dir, device="cpu") kohya_state_dict.update( _convert_peft_state_dict_to_kohya_state_dict( lora_config=lora_config, peft_state_dict=lora_weights, prefix=kohya_prefix, dtype=dtype ) ) save_state_dict(kohya_state_dict, out_checkpoint_file) def _convert_peft_models_to_kohya_state_dict( kohya_prefixes: list[str], models: list[peft.PeftModel] ) -> dict[str, torch.Tensor]: kohya_state_dict = {} default_adapter_name = "default" for kohya_prefix, peft_model in zip(kohya_prefixes, models, strict=True): lora_config = peft_model.peft_config[default_adapter_name] assert isinstance(lora_config, peft.LoraConfig) state_dict = peft.get_peft_model_state_dict(peft_model, adapter_name=default_adapter_name) if kohya_prefix == FLUX_KOHYA_TRANSFORMER_KEY: state_dict = convert_diffusers_to_flux_transformer_checkpoint(state_dict) kohya_state_dict.update( _convert_peft_state_dict_to_kohya_state_dict( lora_config=lora_config, peft_state_dict=state_dict, prefix=kohya_prefix, dtype=torch.float32, ) ) return kohya_state_dict def find_matching_key_prefix(state_dict, key_pattern): """ Find if any key in the state dictionary matches the given pattern. 
Args: state_dict: The state dictionary to search in key_pattern: The pattern to look for in keys Returns: The matching prefix if found, False otherwise """ base_prefix = key_pattern.split(".lora_A")[0].split(".lora_B")[0].split(".weight")[0] for key in state_dict.keys(): if base_prefix in key: return base_prefix return False def convert_layer_weights(target_dict, source_dict, source_pattern, target_pattern): """ Convert weights from source pattern to target pattern if they exist. Args: target_dict: Dictionary to store converted weights source_dict: Source dictionary containing weights source_pattern: Original key pattern to search for target_pattern: New key pattern to use Returns: Tuple of (updated target_dict, updated source_dict) """ if original_key := find_matching_key_prefix(source_dict, source_pattern): # Find all keys matching the pattern keys_to_convert = [k for k in source_dict.keys() if original_key in k] for key in keys_to_convert: # Create replacement key new_key = key.replace(original_key, target_pattern.replace(".weight", "")) # Transfer and remove from original target_dict[new_key] = source_dict.pop(key) return target_dict, source_dict def convert_double_transformer_block(target_dict, source_dict, prefix="", block_idx=0): """ Convert weights for a double transformer block. Args: target_dict: Dictionary to store converted weights source_dict: Source dictionary containing weights prefix: Prefix for the keys in the state dictionary block_idx: Block index Returns: Tuple of (updated target_dict, updated source_dict) """ block_prefix = f"transformer_blocks.{block_idx}." 
# Convert norms target_dict, source_dict = convert_layer_weights( target_dict, source_dict, f"{prefix}{block_prefix}norm1.linear.weight", f"double_blocks.{block_idx}.img_mod.lin.weight", ) target_dict, source_dict = convert_layer_weights( target_dict, source_dict, f"{prefix}{block_prefix}norm1_context.linear.weight", f"double_blocks.{block_idx}.txt_mod.lin.weight", ) # Convert attention weights by concatenating Q, K, V try: # Sample attention weights sample_q_A = source_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_A.weight") sample_q_B = source_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_B.weight") sample_k_A = source_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_A.weight") sample_k_B = source_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_B.weight") sample_v_A = source_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_A.weight") sample_v_B = source_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_B.weight") # Context attention weights context_q_A = source_dict.pop(f"{prefix}{block_prefix}attn.add_q_proj.lora_A.weight") context_q_B = source_dict.pop(f"{prefix}{block_prefix}attn.add_q_proj.lora_B.weight") context_k_A = source_dict.pop(f"{prefix}{block_prefix}attn.add_k_proj.lora_A.weight") context_k_B = source_dict.pop(f"{prefix}{block_prefix}attn.add_k_proj.lora_B.weight") context_v_A = source_dict.pop(f"{prefix}{block_prefix}attn.add_v_proj.lora_A.weight") context_v_B = source_dict.pop(f"{prefix}{block_prefix}attn.add_v_proj.lora_B.weight") # Concatenate Q, K, V for image and text target_dict[f"double_blocks.{block_idx}.img_attn.qkv.lora_A.weight"] = torch.cat( [sample_q_A, sample_k_A, sample_v_A], dim=0 ) target_dict[f"double_blocks.{block_idx}.img_attn.qkv.lora_B.weight"] = torch.cat( [sample_q_B, sample_k_B, sample_v_B], dim=0 ) target_dict[f"double_blocks.{block_idx}.txt_attn.qkv.lora_A.weight"] = torch.cat( [context_q_A, context_k_A, context_v_A], dim=0 ) target_dict[f"double_blocks.{block_idx}.txt_attn.qkv.lora_B.weight"] = torch.cat( [context_q_B, 
context_k_B, context_v_B], dim=0 ) except KeyError as e: print(f"Error processing attention weights for block {block_idx}: {e}") raise # Convert QK norms norm_keys = [ (f"{prefix}{block_prefix}attn.norm_q.weight", f"double_blocks.{block_idx}.img_attn.norm.query_norm.scale"), (f"{prefix}{block_prefix}attn.norm_k.weight", f"double_blocks.{block_idx}.img_attn.norm.key_norm.scale"), ( f"{prefix}{block_prefix}attn.norm_added_q.weight", f"double_blocks.{block_idx}.txt_attn.norm.query_norm.scale", ), (f"{prefix}{block_prefix}attn.norm_added_k.weight", f"double_blocks.{block_idx}.txt_attn.norm.key_norm.scale"), ] for src_key, target_key in norm_keys: target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key) # Convert MLP weights mlp_keys = [ (f"{prefix}{block_prefix}ff.net.0.proj.weight", f"double_blocks.{block_idx}.img_mlp.0.weight"), (f"{prefix}{block_prefix}ff.net.2.weight", f"double_blocks.{block_idx}.img_mlp.2.weight"), (f"{prefix}{block_prefix}ff_context.net.0.proj.weight", f"double_blocks.{block_idx}.txt_mlp.0.weight"), (f"{prefix}{block_prefix}ff_context.net.2.weight", f"double_blocks.{block_idx}.txt_mlp.2.weight"), ] for src_key, target_key in mlp_keys: target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key) # Convert output projections output_keys = [ (f"{prefix}{block_prefix}attn.to_out.0.weight", f"double_blocks.{block_idx}.img_attn.proj.weight"), (f"{prefix}{block_prefix}attn.to_add_out.weight", f"double_blocks.{block_idx}.txt_attn.proj.weight"), ] for src_key, target_key in output_keys: target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key) return target_dict, source_dict def convert_single_transformer_block(target_dict, source_dict, prefix, block_idx): """ Convert weights for a single transformer block. 
Args: target_dict: Dictionary to store converted weights source_dict: Source dictionary containing weights prefix: Prefix for the keys in the state dictionary block_idx: Block index Returns: Tuple of (updated target_dict, updated source_dict) """ block_prefix = f"single_transformer_blocks.{block_idx}." # Convert norm target_dict, source_dict = convert_layer_weights( target_dict, source_dict, f"{prefix}{block_prefix}norm.linear.weight", f"single_blocks.{block_idx}.modulation.lin.weight", ) try: # Convert Q, K, V, MLP by concatenating q_A = source_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_A.weight") q_B = source_dict.pop(f"{prefix}{block_prefix}attn.to_q.lora_B.weight") k_A = source_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_A.weight") k_B = source_dict.pop(f"{prefix}{block_prefix}attn.to_k.lora_B.weight") v_A = source_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_A.weight") v_B = source_dict.pop(f"{prefix}{block_prefix}attn.to_v.lora_B.weight") mlp_A = source_dict.pop(f"{prefix}{block_prefix}proj_mlp.lora_A.weight") mlp_B = source_dict.pop(f"{prefix}{block_prefix}proj_mlp.lora_B.weight") target_dict[f"single_blocks.{block_idx}.linear1.lora_A.weight"] = torch.cat([q_A, k_A, v_A, mlp_A], dim=0) target_dict[f"single_blocks.{block_idx}.linear1.lora_B.weight"] = torch.cat([q_B, k_B, v_B, mlp_B], dim=0) except KeyError as e: print(f"Error processing attention weights for single block {block_idx}: {e}") raise # Convert output projection target_dict, source_dict = convert_layer_weights( target_dict, source_dict, f"{prefix}{block_prefix}proj_out.weight", f"single_blocks.{block_idx}.linear2.weight", ) return target_dict, source_dict def convert_embedding_layers(target_dict, source_dict, prefix, has_guidance=True): """ Convert time, text, guidance, and context embedding layers. 
Args: target_dict: Dictionary to store converted weights source_dict: Source dictionary containing weights prefix: Prefix for the keys in the state dictionary has_guidance: Whether the model has guidance embedding Returns: Tuple of (updated target_dict, updated source_dict) """ # Convert time embedding target_dict, source_dict = convert_layer_weights( target_dict, source_dict, f"{prefix}time_text_embed.timestep_embedder.linear_1.weight", "time_in.in_layer.weight", ) # Convert text embedding text_embed_keys = [ (f"{prefix}time_text_embed.text_embedder.linear_1.weight", "vector_in.in_layer.weight"), (f"{prefix}time_text_embed.text_embedder.linear_2.weight", "vector_in.out_layer.weight"), ] for src_key, target_key in text_embed_keys: target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key) # Convert guidance embedding if needed if has_guidance: guidance_keys = [ (f"{prefix}time_text_embed.guidance_embedder.linear_1.weight", "guidance_in.in_layer.weight"), (f"{prefix}time_text_embed.guidance_embedder.linear_2.weight", "guidance_in.out_layer.weight"), ] for src_key, target_key in guidance_keys: target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key) # Convert context and image embedders embed_keys = [ (f"{prefix}context_embedder.weight", "txt_in.weight"), (f"{prefix}x_embedder.weight", "img_in.weight"), ] for src_key, target_key in embed_keys: target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key) return target_dict, source_dict def convert_output_layers(target_dict, source_dict, prefix): """ Convert final output layers. 
Args: target_dict: Dictionary to store converted weights source_dict: Source dictionary containing weights prefix: Prefix for the keys in the state dictionary Returns: Tuple of (updated target_dict, updated source_dict) """ output_keys = [ (f"{prefix}proj_out.weight", "final_layer.linear.weight"), (f"{prefix}proj_out.bias", "final_layer.linear.bias"), (f"{prefix}norm_out.linear.weight", "final_layer.adaLN_modulation.1.weight"), ] for src_key, target_key in output_keys: target_dict, source_dict = convert_layer_weights(target_dict, source_dict, src_key, target_key) return target_dict, source_dict def convert_diffusers_to_flux_transformer_checkpoint( diffusers_state_dict, num_layers=19, num_single_layers=38, has_guidance=True, old_prefix="base_model.model.", new_prefix=FLUX_KOHYA_TRANSFORMER_KEY, ): """ Convert a diffusers state dictionary to flux transformer checkpoint format. Args: diffusers_state_dict: Source diffusers state dictionary num_layers: Number of double transformer layers num_single_layers: Number of single transformer layers has_guidance: Whether the model has guidance embedding prefix: Prefix for keys in the source dictionary Returns: A new state dictionary in flux transformer format """ # Create a new state dictionary flux_state_dict = {} # Convert embedding layers flux_state_dict, diffusers_state_dict = convert_embedding_layers( flux_state_dict, diffusers_state_dict, old_prefix, has_guidance ) # Convert double transformer blocks for i in range(num_layers): flux_state_dict, diffusers_state_dict = convert_double_transformer_block( flux_state_dict, diffusers_state_dict, old_prefix, i ) # Convert single transformer blocks for i in range(num_single_layers): flux_state_dict, diffusers_state_dict = convert_single_transformer_block( flux_state_dict, diffusers_state_dict, old_prefix, i ) # Convert output layers flux_state_dict, diffusers_state_dict = convert_output_layers(flux_state_dict, diffusers_state_dict, old_prefix) # Check for leftover keys if 
diffusers_state_dict: print(f"Unexpected keys: {list(diffusers_state_dict.keys())}") # Replace the old prefix with the new prefix keys = list(flux_state_dict.keys()) for key in keys: new_key = f"{new_prefix}.{key}" flux_state_dict[new_key] = flux_state_dict.pop(key) return flux_state_dict ================================================ FILE: src/invoke_training/_shared/flux/model_loading_utils.py ================================================ import logging from enum import Enum import torch from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer class PipelineVersionEnum(Enum): FLUX = "FLUX" def load_pipeline( logger: logging.Logger, model_name_or_path: str = "black-forest-labs/FLUX.1-dev", pipeline_version: PipelineVersionEnum = PipelineVersionEnum.FLUX, transformer_path: str | None = None, text_encoder_1_path: str | None = None, text_encoder_2_path: str | None = None, torch_dtype: torch.dtype | None = None, ) -> FluxPipeline: """Load a Flux pipeline with optional custom components from .safetensors files. 
Args: logger: Logger instance model_name_or_path: Base model path or repository pipeline_version: Pipeline version (currently only FLUX supported) transformer_path: Path to custom transformer .safetensors file text_encoder_1_path: Path to custom CLIP text encoder .safetensors file text_encoder_2_path: Path to custom T5 text encoder .safetensors file torch_dtype: Desired dtype for the models Returns: FluxPipeline: Configured pipeline with custom components if specified """ if pipeline_version != PipelineVersionEnum.FLUX: raise ValueError(f"Invalid pipeline version: {pipeline_version}") # Prepare kwargs for from_pretrained kwargs = {"torch_dtype": torch_dtype} # Add components only if custom paths are provided if transformer_path is not None: # load_model_from_file_or_pretrained(FluxTransformer2DModel, transformer_path, torch_dtype=torch_dtype, # use_safetensors=True, subfolder="transformer") kwargs["transformer"] = FluxTransformer2DModel.from_pretrained( transformer_path, torch_dtype=torch_dtype, ) logger.info(f"Loading custom transformer from {transformer_path}") if text_encoder_1_path is not None: logger.info(f"Loading custom CLIP text encoder from {text_encoder_1_path}") kwargs["text_encoder"] = CLIPTextModel.from_pretrained(text_encoder_1_path, torch_dtype=torch_dtype) if text_encoder_2_path is not None: logger.info(f"Loading custom T5 text encoder from {text_encoder_2_path}") kwargs["text_encoder_2"] = T5EncoderModel.from_pretrained(text_encoder_2_path, torch_dtype=torch_dtype) # Load the pipeline with any custom components pipeline = FluxPipeline.from_pretrained(model_name_or_path, **kwargs) return pipeline def load_models_flux( logger: logging.Logger, model_name_or_path: str = "black-forest-labs/FLUX.1-dev", dtype: torch.dtype | None = None, transformer_path: str | None = None, text_encoder_1_path: str | None = None, text_encoder_2_path: str | None = None, ) -> tuple[CLIPTokenizer, FlowMatchEulerDiscreteScheduler, CLIPTextModel, AutoencoderKL, 
FluxTransformer2DModel]: """Load all models required for training from disk, transfer them to the target training device and cast their weight dtypes. Args: logger: Logger instance model_name_or_path: Base model path or repository dtype: Desired dtype for the models transformer_path: Path to custom transformer .safetensors file text_encoder_1_path: Path to custom CLIP text encoder .safetensors file text_encoder_2_path: Path to custom T5 text encoder .safetensors file """ pipeline: FluxPipeline = load_pipeline( logger=logger, model_name_or_path=model_name_or_path, pipeline_version=PipelineVersionEnum.FLUX, transformer_path=transformer_path, text_encoder_1_path=text_encoder_1_path, text_encoder_2_path=text_encoder_2_path, torch_dtype=dtype, ) # Tokenizers and text encoders. tokenizer_1: CLIPTokenizer = pipeline.tokenizer text_encoder_1: CLIPTextModel = pipeline.text_encoder tokenizer_2: T5Tokenizer = pipeline.tokenizer_2 text_encoder_2: T5EncoderModel = pipeline.text_encoder_2 # Transformer and Scheduler transformer: FluxTransformer2DModel = pipeline.transformer noise_scheduler: FlowMatchEulerDiscreteScheduler = pipeline.scheduler # Decoder vae: AutoencoderKL = pipeline.vae # Log component status logger.info( f"Pipeline components loaded: tokenizer_1={tokenizer_1 is not None}, " f"text_encoder_1={text_encoder_1 is not None}, " f"tokenizer_2={tokenizer_2 is not None}, " f"text_encoder_2={text_encoder_2 is not None}, " f"transformer={transformer is not None}, " f"vae={vae is not None}" ) # Check for None components if text_encoder_1 is None: raise ValueError( "text_encoder_1 failed to load. " "Check if you have access to the model repository and are properly authenticated." ) if text_encoder_2 is None: raise ValueError( "text_encoder_2 failed to load. " "Check if you have access to the model repository and are properly authenticated." ) if transformer is None: raise ValueError( "transformer failed to load. 
" "Check if you have access to the model repository and are properly authenticated." ) if vae is None: raise ValueError( "vae failed to load. Check if you have access to the model repository and are properly authenticated." ) # Disable gradient calculation for model weights to save memory. text_encoder_1.requires_grad_(False) text_encoder_2.requires_grad_(False) vae.requires_grad_(False) transformer.requires_grad_(False) if dtype is not None: text_encoder_1 = text_encoder_1.to(dtype=dtype) text_encoder_2 = text_encoder_2.to(dtype=dtype) vae = vae.to(dtype=dtype) transformer = transformer.to(dtype=dtype) # Put models in 'eval' mode. text_encoder_1.eval() text_encoder_2.eval() vae.eval() transformer.eval() return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, transformer ================================================ FILE: src/invoke_training/_shared/flux/validation.py ================================================ import logging import os import numpy as np import torch import torch.utils.data from accelerate import Accelerator from accelerate.hooks import remove_hook_from_module from diffusers import ( AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, ) from peft import PeftModel from transformers import CLIPTextModel, CLIPTokenizer from invoke_training._shared.data.utils.resolution import Resolution from invoke_training.pipelines.callbacks import PipelineCallbacks, ValidationImage, ValidationImages from invoke_training.pipelines.flux.lora.config import FluxLoraConfig NUM_INFERENCE_STEPS = 20 def generate_validation_images_flux( # noqa: C901 epoch: int, step: int, out_dir: str, accelerator: Accelerator, vae: AutoencoderKL, text_encoder_1: CLIPTextModel, text_encoder_2: CLIPTextModel, tokenizer_1: CLIPTokenizer, tokenizer_2: CLIPTokenizer, noise_scheduler: FlowMatchEulerDiscreteScheduler, transformer: FluxTransformer2DModel | PeftModel, config: FluxLoraConfig, logger: logging.Logger, callbacks: 
list[PipelineCallbacks] | None = None, ): """Generate validation images for the purpose of tracking image generation behaviour on fixed prompts throughout training. """ # Record original model devices so that we can restore this state after running the pipeline with CPU model # offloading. transformer_device = transformer.device vae_device = vae.device text_encoder_1_device = text_encoder_1.device text_encoder_2_device = text_encoder_2.device # Create pipeline. pipeline = FluxPipeline( vae=vae, text_encoder=text_encoder_1, text_encoder_2=text_encoder_2, tokenizer=tokenizer_1, tokenizer_2=tokenizer_2, transformer=transformer, scheduler=noise_scheduler, ) if config.enable_cpu_offload_during_validation: pipeline.enable_model_cpu_offload(accelerator.device.index or 0) else: pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) validation_resolution = Resolution.parse(config.data_loader.resolution) validation_images = ValidationImages(images=[], epoch=epoch, step=step) validation_step_dir = os.path.join(out_dir, "validation", f"epoch_{epoch:0>8}-step_{step:0>8}") logger.info(f"Generating validation images ({validation_step_dir}).") # Run inference. with torch.no_grad(): for prompt_idx in range(len(config.validation_prompts)): positive_prompt = config.validation_prompts[prompt_idx] negative_prompt = None logger.info(f"Validation prompt {prompt_idx}, pos: '{positive_prompt}', neg: '{negative_prompt or ''}'") generator = torch.Generator(device=accelerator.device) if config.seed is not None: generator = generator.manual_seed(config.seed) images = [] for _ in range(config.num_validation_images_per_prompt): with accelerator.autocast(): images.append( pipeline( positive_prompt, num_inference_steps=NUM_INFERENCE_STEPS, generator=generator, height=validation_resolution.height, width=validation_resolution.width, negative_prompt=negative_prompt, ).images[0] ) # Save images to disk. 
validation_prompt_dir = os.path.join(validation_step_dir, f"prompt_{prompt_idx:0>4}") os.makedirs(validation_prompt_dir) for image_idx, image in enumerate(images): image_path = os.path.join(validation_prompt_dir, f"{image_idx:0>4}.jpg") validation_images.images.append( ValidationImage(file_path=image_path, prompt=positive_prompt, image_idx=image_idx) ) image.save(image_path) # Log images to trackers. Currently, only tensorboard is supported. for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) tracker.writer.add_images( f"validation (prompt {prompt_idx})", np_images, step, dataformats="NHWC", ) del pipeline torch.cuda.empty_cache() for model in [transformer, vae, text_encoder_1, text_encoder_2]: remove_hook_from_module(model) # Restore models to original devices. transformer.to(transformer_device) vae.to(vae_device) text_encoder_1.to(text_encoder_1_device) text_encoder_2.to(text_encoder_2_device) # Run callbacks. if callbacks is not None: for cb in callbacks: cb.on_save_validation_images(images=validation_images) ================================================ FILE: src/invoke_training/_shared/optimizer/__init__.py ================================================ ================================================ FILE: src/invoke_training/_shared/optimizer/optimizer_utils.py ================================================ import torch from prodigyopt import Prodigy from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig def initialize_optimizer( config: AdamOptimizerConfig | ProdigyOptimizerConfig, trainable_params: list ) -> torch.optim.Optimizer: """Initialize an optimizer based on the provided config.""" if config.optimizer_type == "AdamW": adam_cls = torch.optim.AdamW if config.use_8bit: try: import bitsandbytes # noqa: F401 except ImportError: raise ImportError( "bitsandbytes is not installed. 
bitsandbytes is required to use the 8-bit Adam optimizer. Install " 'it by running `pip install ".[bitsandbytes]"`.' ) adam_cls = bitsandbytes.optim.AdamW8bit optimizer = adam_cls( trainable_params, lr=config.learning_rate, betas=(config.beta1, config.beta2), weight_decay=config.weight_decay, eps=config.epsilon, ) elif config.optimizer_type == "Prodigy": optimizer = Prodigy( trainable_params, lr=config.learning_rate, weight_decay=config.weight_decay, use_bias_correction=config.use_bias_correction, safeguard_warmup=config.safeguard_warmup, ) else: raise ValueError(f"'{config.optimizer_type}' is not a supported optimizer.") return optimizer ================================================ FILE: src/invoke_training/_shared/stable_diffusion/__init__.py ================================================ ================================================ FILE: src/invoke_training/_shared/stable_diffusion/base_model_version.py ================================================ from enum import Enum from transformers import PretrainedConfig class BaseModelVersionEnum(Enum): STABLE_DIFFUSION_V1 = 1 STABLE_DIFFUSION_V2 = 2 STABLE_DIFFUSION_SDXL_BASE = 3 STABLE_DIFFUSION_SDXL_REFINER = 4 def get_base_model_version( diffusers_model_name: str, revision: str = "main", local_files_only: bool = True ) -> BaseModelVersionEnum: """Returns the `BaseModelVersionEnum` of a diffusers model. Args: diffusers_model_name (str): The diffusers model name (on Hugging Face Hub). revision (str, optional): The model revision (branch or commit hash). Defaults to "main". Raises: Exception: If the base model version can not be determined. Returns: BaseModelVersionEnum: The detected base model version. 
""" unet_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path=diffusers_model_name, revision=revision, subfolder="unet", local_files_only=local_files_only, ) # This logic was copied from # https://github.com/invoke-ai/InvokeAI/blob/e77400ab62d24acbdf2f48a7427705e7b8b97e4a/invokeai/backend/model_management/model_probe.py#L412-L421 # This seems fragile. If you see this and know of a better way to detect the base model version, your contribution # would be welcome. if unet_config.cross_attention_dim == 768: return BaseModelVersionEnum.STABLE_DIFFUSION_V1 elif unet_config.cross_attention_dim == 1024: return BaseModelVersionEnum.STABLE_DIFFUSION_V2 elif unet_config.cross_attention_dim == 1280: return BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_REFINER elif unet_config.cross_attention_dim == 2048: return BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE else: raise Exception( "Failed to determine base model version. UNet cross_attention_dim has unexpected value: " f"'{unet_config.cross_attention_dim}'." ) def check_base_model_version( allowed_versions: set[BaseModelVersionEnum], diffusers_model_name: str, revision: str = "main", local_files_only: bool = True, ): """Helper function that checks if a diffusers model is compatible with a set of base model versions. Args: allowed_versions (set[BaseModelVersionEnum]): The set of allowed base model versions. diffusers_model_name (str): The diffusers model name (on Hugging Face Hub) to check. revision (str, optional): The model revision (branch or commit hash). Defaults to "main". Raises: ValueError: If the model has an unsupported version. """ version = get_base_model_version(diffusers_model_name, revision, local_files_only) if version not in allowed_versions: raise ValueError( f"Model '{diffusers_model_name}' (revision='{revision}') has an unsupported version: '{version.name}'. " f"Supported versions: {[v.name for v in allowed_versions]}." 
) ================================================ FILE: src/invoke_training/_shared/stable_diffusion/checkpoint_utils.py ================================================ from pathlib import Path import torch from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel from transformers import CLIPTextModel, CLIPTokenizer def save_sdxl_diffusers_unet_checkpoint( checkpoint_path: Path | str, unet: UNet2DConditionModel, save_dtype: torch.dtype ): # Record original device and dtype so that we can restore it afterward. model_list = [unet] original_devices = [model.device for model in model_list] original_dtypes = [model.dtype for model in model_list] # Save UNet. unet.to(dtype=save_dtype) unet.save_pretrained(Path(checkpoint_path) / "unet") # Restore original device/dtype. for model, device, dtype in zip(model_list, original_devices, original_dtypes, strict=True): model.to(device=device, dtype=dtype) def save_sdxl_diffusers_checkpoint( checkpoint_path: Path | str, vae: AutoencoderKL, text_encoder_1: CLIPTextModel, text_encoder_2: CLIPTextModel, tokenizer_1: CLIPTokenizer, tokenizer_2: CLIPTokenizer, noise_scheduler: DDPMScheduler, unet: UNet2DConditionModel, save_dtype: torch.dtype, ): # Record original device and dtype so that we can restore it afterward. # TODO(ryand): This method of restoring original device/dtype is a bit naive. It does not handle mixed precisions # within a model, and results in a loss of precision if the save_dtype is lower precision than the model dtype. We # may need to revisit this. model_list = [vae, text_encoder_1, text_encoder_2, unet] original_devices = [model.device for model in model_list] original_dtypes = [model.dtype for model in model_list] # Create pipeline. 
pipeline = StableDiffusionXLPipeline( vae=vae, text_encoder=text_encoder_1, text_encoder_2=text_encoder_2, tokenizer=tokenizer_1, tokenizer_2=tokenizer_2, unet=unet, scheduler=noise_scheduler, ) pipeline = pipeline.to(dtype=save_dtype) # Save pipeline. pipeline.save_pretrained(checkpoint_path) # Restore original device/dtype. for model, device, dtype in zip(model_list, original_devices, original_dtypes, strict=True): model.to(device=device, dtype=dtype) ================================================ FILE: src/invoke_training/_shared/stable_diffusion/lora_checkpoint_utils.py ================================================ import os from pathlib import Path import peft import torch from diffusers import UNet2DConditionModel from transformers import CLIPTextModel from invoke_training._shared.checkpoints.lora_checkpoint_utils import ( _convert_peft_models_to_kohya_state_dict, _convert_peft_state_dict_to_kohya_state_dict, load_multi_model_peft_checkpoint, save_multi_model_peft_checkpoint, ) from invoke_training._shared.checkpoints.serialization import save_state_dict # Copied from https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/stable_diffusion/train_dreambooth.py#L49C1-L65C87 UNET_TARGET_MODULES = [ "to_q", "to_k", "to_v", "proj", "proj_in", "proj_out", "conv", "conv1", "conv2", "conv_shortcut", "to_out.0", "time_emb_proj", "ff.net.2", ] TEXT_ENCODER_TARGET_MODULES = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"] # Module lists copied from diffusers training script. # These module lists will produce lighter, less expressive, LoRA models than the non-light versions. 
UNET_TARGET_MODULES_LIGHT = ["to_k", "to_q", "to_v", "to_out.0"] TEXT_ENCODER_TARGET_MODULES_LIGHT = ["q_proj", "k_proj", "v_proj", "out_proj"] SD_PEFT_UNET_KEY = "unet" SD_PEFT_TEXT_ENCODER_KEY = "text_encoder" SDXL_PEFT_UNET_KEY = "unet" SDXL_PEFT_TEXT_ENCODER_1_KEY = "text_encoder_1" SDXL_PEFT_TEXT_ENCODER_2_KEY = "text_encoder_2" SD_KOHYA_UNET_KEY = "lora_unet" SD_KOHYA_TEXT_ENCODER_KEY = "lora_te" SDXL_KOHYA_UNET_KEY = "lora_unet" SDXL_KOHYA_TEXT_ENCODER_1_KEY = "lora_te1" SDXL_KOHYA_TEXT_ENCODER_2_KEY = "lora_te2" SD_PEFT_TO_KOHYA_KEYS = { SD_PEFT_UNET_KEY: SD_KOHYA_UNET_KEY, SD_PEFT_TEXT_ENCODER_KEY: SD_KOHYA_TEXT_ENCODER_KEY, } SDXL_PEFT_TO_KOHYA_KEYS = { SDXL_PEFT_UNET_KEY: SDXL_KOHYA_UNET_KEY, SDXL_PEFT_TEXT_ENCODER_1_KEY: SDXL_KOHYA_TEXT_ENCODER_1_KEY, SDXL_PEFT_TEXT_ENCODER_2_KEY: SDXL_KOHYA_TEXT_ENCODER_2_KEY, } def save_sd_peft_checkpoint( checkpoint_dir: Path | str, unet: peft.PeftModel | None, text_encoder: peft.PeftModel | None ): models = {} if unet is not None: models[SD_PEFT_UNET_KEY] = unet if text_encoder is not None: models[SD_PEFT_TEXT_ENCODER_KEY] = text_encoder save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=models) def load_sd_peft_checkpoint( checkpoint_dir: Path | str, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, is_trainable: bool = False ): models = load_multi_model_peft_checkpoint( checkpoint_dir=checkpoint_dir, models={SD_PEFT_UNET_KEY: unet, SD_PEFT_TEXT_ENCODER_KEY: text_encoder}, is_trainable=is_trainable, raise_if_subdir_missing=False, ) return models[SD_PEFT_UNET_KEY], models[SD_PEFT_TEXT_ENCODER_KEY] def save_sdxl_peft_checkpoint( checkpoint_dir: Path | str, unet: peft.PeftModel | None, text_encoder_1: peft.PeftModel | None, text_encoder_2: peft.PeftModel | None, ): models = {} if unet is not None: models[SDXL_PEFT_UNET_KEY] = unet if text_encoder_1 is not None: models[SDXL_PEFT_TEXT_ENCODER_1_KEY] = text_encoder_1 if text_encoder_2 is not None: models[SDXL_PEFT_TEXT_ENCODER_2_KEY] = 
text_encoder_2 save_multi_model_peft_checkpoint(checkpoint_dir=checkpoint_dir, models=models) def load_sdxl_peft_checkpoint( checkpoint_dir: Path | str, unet: UNet2DConditionModel, text_encoder_1: CLIPTextModel, text_encoder_2: CLIPTextModel, is_trainable: bool = False, ): models = load_multi_model_peft_checkpoint( checkpoint_dir=checkpoint_dir, models={ SDXL_PEFT_UNET_KEY: unet, SDXL_PEFT_TEXT_ENCODER_1_KEY: text_encoder_1, SDXL_PEFT_TEXT_ENCODER_2_KEY: text_encoder_2, }, is_trainable=is_trainable, raise_if_subdir_missing=False, ) return models[SDXL_PEFT_UNET_KEY], models[SDXL_PEFT_TEXT_ENCODER_1_KEY], models[SDXL_PEFT_TEXT_ENCODER_2_KEY] def save_sd_kohya_checkpoint(checkpoint_path: Path, unet: peft.PeftModel | None, text_encoder: peft.PeftModel | None): kohya_prefixes = [] models = [] for kohya_prefix, peft_model in zip([SD_KOHYA_UNET_KEY, SD_KOHYA_TEXT_ENCODER_KEY], [unet, text_encoder]): if peft_model is not None: kohya_prefixes.append(kohya_prefix) models.append(peft_model) kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models) checkpoint_path.parent.mkdir(parents=True, exist_ok=True) save_state_dict(kohya_state_dict, checkpoint_path) def save_sdxl_kohya_checkpoint( checkpoint_path: Path, unet: peft.PeftModel | None, text_encoder_1: peft.PeftModel | None, text_encoder_2: peft.PeftModel | None, ): kohya_prefixes = [] models = [] for kohya_prefix, peft_model in zip( [SDXL_KOHYA_UNET_KEY, SDXL_KOHYA_TEXT_ENCODER_1_KEY, SDXL_KOHYA_TEXT_ENCODER_2_KEY], [unet, text_encoder_1, text_encoder_2], ): if peft_model is not None: kohya_prefixes.append(kohya_prefix) models.append(peft_model) kohya_state_dict = _convert_peft_models_to_kohya_state_dict(kohya_prefixes=kohya_prefixes, models=models) checkpoint_path.parent.mkdir(parents=True, exist_ok=True) save_state_dict(kohya_state_dict, checkpoint_path) def convert_sd_peft_checkpoint_to_kohya_state_dict( in_checkpoint_dir: Path, out_checkpoint_file: Path, dtype: torch.dtype 
= torch.float32, ) -> dict[str, torch.Tensor]: """Convert SD v1 or SDXL PEFT models to a Kohya-format LoRA state dict.""" # Get the immediate subdirectories of the checkpoint directory. We assume that each subdirectory is a PEFT model. peft_model_dirs = os.listdir(in_checkpoint_dir) peft_model_dirs = [in_checkpoint_dir / d for d in peft_model_dirs] # Convert to Path objects. peft_model_dirs = [d for d in peft_model_dirs if d.is_dir()] # Filter out non-directories. if len(peft_model_dirs) == 0: raise ValueError(f"No checkpoint files found in directory '{in_checkpoint_dir}'.") kohya_state_dict = {} for peft_model_dir in peft_model_dirs: if peft_model_dir.name in SD_PEFT_TO_KOHYA_KEYS: kohya_prefix = SD_PEFT_TO_KOHYA_KEYS[peft_model_dir.name] elif peft_model_dir.name in SDXL_PEFT_TO_KOHYA_KEYS: kohya_prefix = SDXL_PEFT_TO_KOHYA_KEYS[peft_model_dir.name] else: raise ValueError(f"Unrecognized checkpoint directory: '{peft_model_dir}'.") # Note: This logic to load the LoraConfig and weights directly is based on how it is done here: # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/src/peft/peft_model.py#L672-L689 # This may need to be updated in the future to support other adapter types (LoKr, LoHa, etc.). # Also, I could see this interface breaking in the future. lora_config = peft.LoraConfig.from_pretrained(peft_model_dir) lora_weights = peft.utils.load_peft_weights(peft_model_dir, device="cpu") kohya_state_dict.update( _convert_peft_state_dict_to_kohya_state_dict( lora_config=lora_config, peft_state_dict=lora_weights, prefix=kohya_prefix, dtype=dtype ) ) save_state_dict(kohya_state_dict, out_checkpoint_file) ================================================ FILE: src/invoke_training/_shared/stable_diffusion/min_snr_weighting.py ================================================ import torch from diffusers import DDPMScheduler def compute_snr(noise_scheduler: DDPMScheduler, timesteps: torch.Tensor): """ Computes SNR. 
Adapted from: https://github.com/huggingface/diffusers/blob/ea9dc3fa90c70c7cd825ca2346a31153e08b5367/src/diffusers/training_utils.py#L40 Which was originally based on: https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 """ alphas_cumprod = noise_scheduler.alphas_cumprod sqrt_alphas_cumprod = alphas_cumprod**0.5 sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 # Expand the tensors. # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] alpha = sqrt_alphas_cumprod.expand(timesteps.shape) sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) # Compute SNR. 
snr = (alpha / sigma) ** 2 return snr ================================================ FILE: src/invoke_training/_shared/stable_diffusion/model_loading_utils.py ================================================ import logging import os import typing from enum import Enum import torch from diffusers import ( AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel, ) from transformers import CLIPTextModel, CLIPTokenizer from invoke_training._shared.checkpoints.serialization import load_state_dict HF_VARIANT_FALLBACKS = [None, "fp16"] class PipelineVersionEnum(Enum): SD = "SD" SDXL = "SDXL" def load_pipeline( logger: logging.Logger, model_name_or_path: str, pipeline_version: PipelineVersionEnum, torch_dtype: torch.dtype = None, variant: str | None = None, ) -> typing.Union[StableDiffusionPipeline, StableDiffusionXLPipeline]: """Load a Stable Diffusion pipeline from disk. Args: model_name_or_path (str): The name or path of the pipeline to load. Can be in diffusers format, or a single stable diffusion checkpoint file. (E.g. 'runwayml/stable-diffusion-v1-5', 'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/realisticVisionV51_v51VAE.safetensors', etc. ) pipeline_version (PipelineVersionEnum): The pipeline version. variant (str | None): The Hugging Face Hub variant. Only applies if `model_name_or_path` is a HF Hub model name. Returns: typing.Union[StableDiffusionPipeline, StableDiffusionXLPipeline]: The loaded pipeline. 
""" if pipeline_version == PipelineVersionEnum.SD: pipeline_class = StableDiffusionPipeline elif pipeline_version == PipelineVersionEnum.SDXL: pipeline_class = StableDiffusionXLPipeline else: raise ValueError(f"Unsupported pipeline_version: '{pipeline_version}'.") if os.path.isfile(model_name_or_path): return pipeline_class.from_single_file( model_name_or_path, torch_dtype=torch_dtype, safety_checker=None, feature_extractor=None, ) return from_pretrained_with_variant_fallback( logger=logger, model_class=pipeline_class, model_name_or_path=model_name_or_path, torch_dtype=torch_dtype, variant=variant, # kwargs safety_checker=None, requires_safety_checker=False, ) ModelT = typing.TypeVar("ModelT") def from_pretrained_with_variant_fallback( logger: logging.Logger, model_class: typing.Type[ModelT], model_name_or_path: str, torch_dtype: torch.dtype | None = None, variant: str | None = None, **kwargs, ) -> ModelT: """A wrapper for .from_pretrained() that tries multiple variants if the initial one fails.""" variants_to_try = [variant] + [v for v in HF_VARIANT_FALLBACKS if v != variant] model: ModelT | None = None for variant_to_try in variants_to_try: if variant_to_try != variant: logger.warning(f"Trying fallback variant '{variant_to_try}'.") try: model = model_class.from_pretrained( model_name_or_path, torch_dtype=torch_dtype, variant=variant_to_try, **kwargs, ) except (OSError, ValueError) as e: error_str = str(e) if "no file named" in error_str or "no such modeling files are available" in error_str: # Ok; we'll try the variant fallbacks. logger.warning(f"Failed to load '{model_name_or_path}' with variant '{variant_to_try}'. 
Error: {e}.") else: raise if model is not None: break if model is None: raise RuntimeError(f"Failed to load model '{model_name_or_path}'.") return model def load_models_sd( logger: logging.Logger, model_name_or_path: str, hf_variant: str | None = None, base_embeddings: dict[str, str] = None, dtype: torch.dtype | None = None, ) -> tuple[CLIPTokenizer, DDPMScheduler, CLIPTextModel, AutoencoderKL, UNet2DConditionModel]: """Load all models required for training from disk, transfer them to the target training device and cast their weight dtypes. """ base_embeddings = base_embeddings or {} pipeline: StableDiffusionPipeline = load_pipeline( logger=logger, model_name_or_path=model_name_or_path, pipeline_version=PipelineVersionEnum.SD, variant=hf_variant, ) for token, embedding_path in base_embeddings.items(): pipeline.load_textual_inversion(embedding_path, token=token) # Extract sub-models from the pipeline. tokenizer: CLIPTokenizer = pipeline.tokenizer text_encoder: CLIPTextModel = pipeline.text_encoder vae: AutoencoderKL = pipeline.vae unet: UNet2DConditionModel = pipeline.unet noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False, steps_offset=1, ) # Disable gradient calculation for model weights to save memory. text_encoder.requires_grad_(False) vae.requires_grad_(False) unet.requires_grad_(False) if dtype is not None: text_encoder = text_encoder.to(dtype=dtype) vae = vae.to(dtype=dtype) unet = unet.to(dtype=dtype) # Put models in 'eval' mode. 
text_encoder.eval() vae.eval() unet.eval() return tokenizer, noise_scheduler, text_encoder, vae, unet def load_models_sdxl( logger: logging.Logger, model_name_or_path: str, hf_variant: str | None = None, vae_model: str | None = None, base_embeddings: dict[str, str] = None, dtype: torch.dtype | None = None, ) -> tuple[ CLIPTokenizer, CLIPTokenizer, DDPMScheduler, CLIPTextModel, CLIPTextModel, AutoencoderKL, UNet2DConditionModel, ]: """Load all models required for training, transfer them to the target training device and cast their weight dtypes. """ base_embeddings = base_embeddings or {} pipeline: StableDiffusionXLPipeline = load_pipeline( logger=logger, model_name_or_path=model_name_or_path, pipeline_version=PipelineVersionEnum.SDXL, variant=hf_variant, ) for token, embedding_path in base_embeddings.items(): state_dict = load_state_dict(embedding_path) pipeline.load_textual_inversion( state_dict["clip_l"], token=token, text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer, ) pipeline.load_textual_inversion( state_dict["clip_g"], token=token, text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2, ) # Extract sub-models from the pipeline. tokenizer_1: CLIPTokenizer = pipeline.tokenizer tokenizer_2: CLIPTokenizer = pipeline.tokenizer_2 text_encoder_1: CLIPTextModel = pipeline.text_encoder text_encoder_2: CLIPTextModel = pipeline.text_encoder_2 vae: AutoencoderKL = pipeline.vae unet: UNet2DConditionModel = pipeline.unet noise_scheduler = DDPMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False, steps_offset=1, ) if vae_model is not None: vae: AutoencoderKL = AutoencoderKL.from_pretrained(vae_model) # Disable gradient calculation for model weights to save memory. 
text_encoder_1.requires_grad_(False) text_encoder_2.requires_grad_(False) vae.requires_grad_(False) unet.requires_grad_(False) if dtype is not None: text_encoder_1 = text_encoder_1.to(dtype=dtype) text_encoder_2 = text_encoder_2.to(dtype=dtype) vae = vae.to(dtype=dtype) unet = unet.to(dtype=dtype) # Put models in 'eval' mode. text_encoder_1.eval() text_encoder_2.eval() vae.eval() unet.eval() return tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet ================================================ FILE: src/invoke_training/_shared/stable_diffusion/textual_inversion.py ================================================ import logging import torch from accelerate import Accelerator from transformers import CLIPTextModel, CLIPTokenizer, PreTrainedTokenizer from invoke_training._shared.checkpoints.serialization import load_state_dict def _expand_placeholder_token(placeholder_token: str, num_vectors: int = 1) -> list[str]: """Expand a placeholder token into a list of placeholder tokens based on the number of embedding vectors being trained. """ placeholder_tokens = [placeholder_token] if num_vectors < 1: raise ValueError(f"num_vectors must be >1, but is '{num_vectors}'.") # Add dummy placeholder tokens if num_vectors > 1. for i in range(1, num_vectors): placeholder_tokens.append(f"{placeholder_token}_{i}") return placeholder_tokens def _add_tokens_to_tokenizer(placeholder_tokens: list[str], tokenizer: PreTrainedTokenizer): """Add new tokens to a tokenizer. Raises: ValueError: Raises if the tokenizer already contains one of the tokens in `placeholder_tokens`. """ num_added_tokens = tokenizer.add_tokens(placeholder_tokens) if num_added_tokens != len(placeholder_tokens): raise ValueError( f"The tokenizer already contains one of the tokens in '{placeholder_tokens}'. Please pass a different" " 'placeholder_token' that is not already in the tokenizer." 
) def expand_placeholders_in_caption(caption: str, tokenizer: CLIPTokenizer) -> str: """Expand any multi-vector placeholder tokens in the caption. For example, "a dog in the style of my_placeholder", could get expanded to "a dog in the style of my_placeholder my_placeholder_1 my_placeholder_2". This implementation is based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders/textual_inversion.py#L144. This logic gets applied automatically when running a full diffusers text-to-image pipeline. """ tokens = tokenizer.tokenize(caption) unique_tokens = set(tokens) for token in unique_tokens: if token in tokenizer.added_tokens_encoder: replacement = token i = 1 while f"{token}_{i}" in tokenizer.added_tokens_encoder: replacement += f" {token}_{i}" i += 1 if replacement != token: # If the replacement is different from the original token, then we double check that the replacement # isn't already in the caption. If the replacement is already in the caption, this probably means that # someone didn't realize that placeholder expansion is handled here. assert replacement not in caption caption = caption.replace(token, replacement) return caption def initialize_placeholder_tokens_from_initializer_token( tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, initializer_token: str, placeholder_token: str, num_vectors: int, logger: logging.Logger, ) -> tuple[list[str], list[int]]: # Convert the initializer_token to a token id. initializer_token_ids = tokenizer.encode(initializer_token, add_special_tokens=False) if len(initializer_token_ids) > 1: logger.warning( f"The initializer_token '{initializer_token}' gets tokenized to {len(initializer_token_ids)} tokens. " "Only the first token will be used. It is recommended to choose a different initializer_token that maps to " "a single token." ) initializer_token_id = initializer_token_ids[0] # Expand the tokenizer / text_encoder to include the placeholder tokens. 
placeholder_tokens = _expand_placeholder_token(placeholder_token, num_vectors=num_vectors) _add_tokens_to_tokenizer(placeholder_tokens, tokenizer) # Resize the token embeddings as we have added new special tokens to the tokenizer. text_encoder.resize_token_embeddings(len(tokenizer)) placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens) # convert_tokens_to_ids returns a `int | list[int]` type, but since we pass in a list it should always return a # list. assert isinstance(placeholder_token_ids, list) # Initialize the newly-added placeholder token(s) with the embeddings of the initializer token. token_embeds = text_encoder.get_input_embeddings().weight.data with torch.no_grad(): for placeholder_token_id in placeholder_token_ids: token_embeds[placeholder_token_id] = token_embeds[initializer_token_id].clone() return placeholder_tokens, placeholder_token_ids def initialize_placeholder_tokens_from_initial_phrase( tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, initial_phrase: str, placeholder_token: str ) -> tuple[list[str], list[int]]: # Convert the initial_phrase to token ids. initial_token_ids = tokenizer.encode(initial_phrase, add_special_tokens=False) # Expand the tokenizer / text_encoder to include one placeholder token for each token in the initial_phrase. placeholder_tokens = _expand_placeholder_token(placeholder_token, num_vectors=len(initial_token_ids)) _add_tokens_to_tokenizer(placeholder_tokens, tokenizer) # Resize the token embeddings as we have added new special tokens to the tokenizer. text_encoder.resize_token_embeddings(len(tokenizer)) placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens) # convert_tokens_to_ids returns a `int | list[int]` type, but since we pass in a list it should always return a # list. assert isinstance(placeholder_token_ids, list) # Initialize the newly-added placeholder token(s) with the embeddings of the initial phrase. 
token_embeds = text_encoder.get_input_embeddings().weight.data with torch.no_grad(): for placeholder_token_id, initial_token_id in zip(placeholder_token_ids, initial_token_ids): token_embeds[placeholder_token_id] = token_embeds[initial_token_id].clone() return placeholder_tokens, placeholder_token_ids def initialize_placeholder_tokens_from_initial_embedding( tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, initial_embedding_file: str, placeholder_token: str, num_vectors: int, ) -> tuple[list[str], list[int]]: # Expand the tokenizer / text_encoder to include the placeholder tokens. placeholder_tokens = _expand_placeholder_token(placeholder_token, num_vectors=num_vectors) _add_tokens_to_tokenizer(placeholder_tokens, tokenizer) # Resize the token embeddings as we have added new special tokens to the tokenizer. text_encoder.resize_token_embeddings(len(tokenizer)) state_dict = load_state_dict(initial_embedding_file) if placeholder_token not in state_dict: raise ValueError( f"The initial embedding at '{initial_embedding_file}' does not contain an embedding for placeholder token " f"'{placeholder_token}'." ) embeddings = state_dict[placeholder_token] if embeddings.shape[0] != len(placeholder_tokens): raise ValueError( f"The number of initial embeddings in '{initial_embedding_file}' ({embeddings.shape[0]}) does not match " f"the expected number of placeholder tokens ({len(placeholder_tokens)})." ) placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens) # convert_tokens_to_ids returns a `int | list[int]` type, but since we pass in a list it should always return a # list. assert isinstance(placeholder_token_ids, list) # Initialize the newly-added placeholder token(s) with the loaded embeddings. 
token_embeds = text_encoder.get_input_embeddings().weight.data with torch.no_grad(): for i, token_id in enumerate(placeholder_token_ids): token_embeds[token_id] = embeddings[i].clone() return placeholder_tokens, placeholder_token_ids def restore_original_embeddings( tokenizer: CLIPTokenizer, placeholder_token_ids: list[int], accelerator: Accelerator, text_encoder: CLIPTextModel, orig_embeds_params: torch.Tensor, ): """Restore the text_encoder embeddings that we are not actively training to make sure they don't change. TODO(ryand): Look into whether this is actually necessary if we set requires_grad correctly. """ index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False index_updates = ~index_no_updates with torch.no_grad(): unwrapped_text_encoder = accelerator.unwrap_model(text_encoder) unwrapped_text_encoder.get_input_embeddings().weight[index_no_updates] = orig_embeds_params[index_no_updates] target_std = unwrapped_text_encoder.get_input_embeddings().weight[index_no_updates].std() new_embeddings = unwrapped_text_encoder.get_input_embeddings().weight[index_updates] target_over_new_std = target_std / new_embeddings.std() # Scale the new embeddings towards the target embeddings. Raise to the 0.1 power to avoid large changes. new_embeddings = new_embeddings * (target_over_new_std**0.1) unwrapped_text_encoder.get_input_embeddings().weight[index_updates] = new_embeddings ================================================ FILE: src/invoke_training/_shared/stable_diffusion/tokenize_captions.py ================================================ import torch from transformers import CLIPTokenizer from invoke_training._shared.stable_diffusion.textual_inversion import expand_placeholders_in_caption def tokenize_captions(tokenizer: CLIPTokenizer, captions: list[str]) -> torch.Tensor: """Tokenize a list of caption. Args: tokenizer (CLIPTokenizer): The tokenizer. 
caption (str): The caption. Returns: torch.Tensor: The token IDs. """ caption_token_ids = [] for caption in captions: caption = expand_placeholders_in_caption(caption, tokenizer) input = tokenizer( caption, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt", ) caption_token_ids.append(input.input_ids[0, ...]) caption_token_ids = torch.stack(caption_token_ids) return caption_token_ids ================================================ FILE: src/invoke_training/_shared/stable_diffusion/validation.py ================================================ import logging import os import numpy as np import torch import torch.utils.data from accelerate import Accelerator from accelerate.hooks import remove_hook_from_module from diffusers import ( AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel, ) from transformers import CLIPTextModel, CLIPTokenizer from invoke_training._shared.data.utils.resolution import Resolution from invoke_training.pipelines.callbacks import PipelineCallbacks, ValidationImage, ValidationImages from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig def generate_validation_images_sd( # noqa: C901 epoch: int, step: int, out_dir: str, accelerator: Accelerator, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, noise_scheduler: DDPMScheduler, unet: UNet2DConditionModel, config: SdLoraConfig, logger: logging.Logger, callbacks: list[PipelineCallbacks] | None = None, ): """Generate validation images for the purpose of tracking image generation behaviour on fixed prompts throughout training. """ # Record original model devices so that we can restore this state after running the pipeline with CPU model # offloading. unet_device = unet.device vae_device = vae.device text_encoder_device = text_encoder.device # Create pipeline. 
pipeline = StableDiffusionPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=noise_scheduler, safety_checker=None, feature_extractor=None, # TODO(ryand): Add safety checker support. requires_safety_checker=False, ) if config.enable_cpu_offload_during_validation: pipeline.enable_model_cpu_offload(accelerator.device.index or 0) else: pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) validation_resolution = Resolution.parse(config.data_loader.resolution) validation_images = ValidationImages(images=[], epoch=epoch, step=step) validation_step_dir = os.path.join(out_dir, "validation", f"epoch_{epoch:0>8}-step_{step:0>8}") logger.info(f"Generating validation images ({validation_step_dir}).") # Run inference. with torch.no_grad(): for prompt_idx in range(len(config.validation_prompts)): positive_prompt = config.validation_prompts[prompt_idx] negative_prompt = None if config.negative_validation_prompts is not None: negative_prompt = config.negative_validation_prompts[prompt_idx] logger.info(f"Validation prompt {prompt_idx}, pos: '{positive_prompt}', neg: '{negative_prompt or ''}'") generator = torch.Generator(device=accelerator.device) if config.seed is not None: generator = generator.manual_seed(config.seed) images = [] for _ in range(config.num_validation_images_per_prompt): with accelerator.autocast(): images.append( pipeline( positive_prompt, num_inference_steps=30, generator=generator, height=validation_resolution.height, width=validation_resolution.width, negative_prompt=negative_prompt, ).images[0] ) # Save images to disk. 
validation_prompt_dir = os.path.join(validation_step_dir, f"prompt_{prompt_idx:0>4}") os.makedirs(validation_prompt_dir) for image_idx, image in enumerate(images): image_path = os.path.join(validation_prompt_dir, f"{image_idx:0>4}.jpg") validation_images.images.append( ValidationImage(file_path=image_path, prompt=positive_prompt, image_idx=image_idx) ) image.save(image_path) # Log images to trackers. Currently, only tensorboard is supported. for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) tracker.writer.add_images( f"validation (prompt {prompt_idx})", np_images, step, dataformats="NHWC", ) del pipeline torch.cuda.empty_cache() # Remove hooks from models. # HACK(ryand): Hooks get added when calling `pipeline.enable_model_cpu_offload(...)`, but `StableDiffusionPipeline` # does not offer a way to clean them up so we have to do this manually. for model in [unet, vae, text_encoder]: remove_hook_from_module(model) # Restore models to original devices. unet.to(unet_device) vae.to(vae_device) text_encoder.to(text_encoder_device) # Run callbacks. if callbacks is not None: for cb in callbacks: cb.on_save_validation_images(images=validation_images) def generate_validation_images_sdxl( # noqa: C901 epoch: int, step: int, out_dir: str, accelerator: Accelerator, vae: AutoencoderKL, text_encoder_1: CLIPTextModel, text_encoder_2: CLIPTextModel, tokenizer_1: CLIPTokenizer, tokenizer_2: CLIPTokenizer, noise_scheduler: DDPMScheduler, unet: UNet2DConditionModel, config: SdxlLoraConfig, logger: logging.Logger, callbacks: list[PipelineCallbacks] | None = None, ): """Generate validation images for the purpose of tracking image generation behaviour on fixed prompts throughout training. """ # Record original model devices so that we can restore this state after running the pipeline with CPU model # offloading. 
unet_device = unet.device vae_device = vae.device text_encoder_1_device = text_encoder_1.device text_encoder_2_device = text_encoder_2.device # Create pipeline. pipeline = StableDiffusionXLPipeline( vae=vae, text_encoder=text_encoder_1, text_encoder_2=text_encoder_2, tokenizer=tokenizer_1, tokenizer_2=tokenizer_2, unet=unet, scheduler=noise_scheduler, ) if config.enable_cpu_offload_during_validation: pipeline.enable_model_cpu_offload(accelerator.device.index or 0) else: pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) validation_resolution = Resolution.parse(config.data_loader.resolution) validation_images = ValidationImages(images=[], epoch=epoch, step=step) validation_step_dir = os.path.join(out_dir, "validation", f"epoch_{epoch:0>8}-step_{step:0>8}") logger.info(f"Generating validation images ({validation_step_dir}).") # Run inference. with torch.no_grad(): for prompt_idx in range(len(config.validation_prompts)): positive_prompt = config.validation_prompts[prompt_idx] negative_prompt = None if config.negative_validation_prompts is not None: negative_prompt = config.negative_validation_prompts[prompt_idx] logger.info(f"Validation prompt {prompt_idx}, pos: '{positive_prompt}', neg: '{negative_prompt or ''}'") generator = torch.Generator(device=accelerator.device) if config.seed is not None: generator = generator.manual_seed(config.seed) images = [] for _ in range(config.num_validation_images_per_prompt): with accelerator.autocast(): images.append( pipeline( positive_prompt, num_inference_steps=30, generator=generator, height=validation_resolution.height, width=validation_resolution.width, negative_prompt=negative_prompt, ).images[0] ) # Save images to disk. 
validation_prompt_dir = os.path.join(validation_step_dir, f"prompt_{prompt_idx:0>4}") os.makedirs(validation_prompt_dir) for image_idx, image in enumerate(images): image_path = os.path.join(validation_prompt_dir, f"{image_idx:0>4}.jpg") validation_images.images.append( ValidationImage(file_path=image_path, prompt=positive_prompt, image_idx=image_idx) ) image.save(image_path) # Log images to trackers. Currently, only tensorboard is supported. for tracker in accelerator.trackers: if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) tracker.writer.add_images( f"validation (prompt {prompt_idx})", np_images, step, dataformats="NHWC", ) del pipeline torch.cuda.empty_cache() # Remove hooks from models. # HACK(ryand): Hooks get added when calling `pipeline.enable_model_cpu_offload(...)`, but # `StableDiffusionXLPipeline` does not offer a way to clean them up so we have to do this manually. for model in [unet, vae, text_encoder_1, text_encoder_2]: remove_hook_from_module(model) # Restore models to original devices. unet.to(unet_device) vae.to(vae_device) text_encoder_1.to(text_encoder_1_device) text_encoder_2.to(text_encoder_2_device) # Run callbacks. 
if callbacks is not None: for cb in callbacks: cb.on_save_validation_images(images=validation_images) ================================================ FILE: src/invoke_training/_shared/tools/__init__.py ================================================ ================================================ FILE: src/invoke_training/_shared/tools/generate_images.py ================================================ import os from pathlib import Path from typing import Optional import torch from tqdm import tqdm from invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset from invoke_training._shared.stable_diffusion.model_loading_utils import ( PipelineVersionEnum, load_pipeline, ) def generate_images( out_dir: str, model: str, hf_variant: str | None, pipeline_version: PipelineVersionEnum, prompts: list[str], set_size: int, num_sets: int, height: int, width: int, loras: Optional[list[tuple[Path, float]]] = None, ti_embeddings: Optional[list[str]] = None, seed: int = 0, torch_dtype: torch.dtype = torch.float16, torch_device: str = "cuda", enable_cpu_offload: bool = False, ): """Generate a set of images and store them in a directory. Typically used to generate a datasets for prior preservation / regularization. Args: out_dir (str): The output directory to create. model (str): The name or path of the diffusers pipeline to generate with. sd_version (PipelineVersionEnum): The model version. prompt (str): The prompt to generate images with. set_size (int): The number of images in a 'set' for a given prompt. num_sets (int): The number of 'sets' to generate for each prompt. height (int): The output image height in pixels (recommended to match the resolution that the model was trained with). width (int): The output image width in pixels (recommended to match the resolution that the model was trained with). loras (list[tuple[Path, float]], optional): Paths to LoRA models to apply to the base model with associated weights. 
ti_embeddings (list[str], optional): Paths to TI embeddings to apply to the base model. seed (int, optional): A seed for repeatability. Defaults to 0. torch_dtype (torch.dtype, optional): The torch dtype. Defaults to torch.float16. torch_device (str, optional): The torch device. Defaults to "cuda". enable_cpu_offload (bool, optional): If True, models will be loaded onto the GPU one by one to conserve VRAM. Defaults to False. """ pipeline = load_pipeline(model_name_or_path=model, pipeline_version=pipeline_version, variant=hf_variant) loras = loras or [] for lora in loras: lora_path, lora_scale = lora pipeline.load_lora_weights(str(lora_path), weight_name=str(lora_path.name)) pipeline.fuse_lora(lora_scale=lora_scale) ti_embeddings = ti_embeddings or [] for ti_embedding in ti_embeddings: pipeline.load_textual_inversion(ti_embedding) pipeline.to(torch_dtype=torch_dtype) if enable_cpu_offload: pipeline.enable_model_cpu_offload() else: pipeline.to(torch_device=torch_device) pipeline.set_progress_bar_config(disable=True) generator = torch.Generator(device=torch_device) if seed is not None: generator = generator.manual_seed(seed) os.makedirs(out_dir) metadata = [] total_images = num_sets * len(prompts) * set_size with torch.no_grad(), tqdm(total=total_images) as pbar: for prompt_idx in range(len(prompts)): for set_idx in range(num_sets): set_dir = os.path.join(out_dir, f"prompt-{prompt_idx:0>4}", f"set-{set_idx:0>4}") os.makedirs(set_dir) set_metadata_dict = {"prompt": prompts[prompt_idx]} for image_idx in range(set_size): image = pipeline( prompts[prompt_idx], num_inference_steps=30, generator=generator, height=height, width=width, ).images[0] image_path = os.path.join(set_dir, f"image-{image_idx}.jpg") image.save(image_path) set_metadata_dict[f"image_{image_idx}"] = os.path.relpath(image_path, start=out_dir) set_metadata_dict[f"prefer_{image_idx}"] = False pbar.update(1) metadata.append(set_metadata_dict) ImagePairPreferenceDataset.save_metadata(metadata=metadata, 
dataset_dir=out_dir) ================================================ FILE: src/invoke_training/_shared/utils/import_xformers.py ================================================ def import_xformers(): try: import xformers # noqa: F401 except ImportError: raise ImportError( "xformers is not installed. Either set `xformers = False` in your training config, or install it using " '`pip install ".[xformers]"`.' ) ================================================ FILE: src/invoke_training/_shared/utils/jsonl.py ================================================ import json from pathlib import Path from typing import Any def load_jsonl(jsonl_path: Path | str) -> list[Any]: """Load a JSONL file.""" data = [] with open(jsonl_path) as f: while (line := f.readline().strip()) != "": data.append(json.loads(line)) return data def save_jsonl(data: list[Any], jsonl_path: Path | str) -> None: """Save a list of objects to a JSONL file.""" with open(jsonl_path, "w") as f: for line in data: f.write(json.dumps(line) + "\n") ================================================ FILE: src/invoke_training/config/__init__.py ================================================ ================================================ FILE: src/invoke_training/config/base_pipeline_config.py ================================================ import typing from typing import Optional from invoke_training.config.config_base_model import ConfigBaseModel class BasePipelineConfig(ConfigBaseModel): """A base config with fields that should be inherited by all pipelines.""" type: str seed: Optional[int] = None """A randomization seed for reproducible training. Set to any constant integer for consistent training results. If set to `null`, training will be non-deterministic. """ base_output_dir: str """The output directory where the training outputs (model checkpoints, logs, intermediate predictions) will be written. A subdirectory will be created with a timestamp for each new training run. 
""" report_to: typing.Literal["all", "tensorboard", "wandb", "comet_ml"] = "tensorboard" """The integration to report results and logs to. This value is passed to Hugging Face Accelerate. See `accelerate.Accelerator.log_with` for more details. """ max_train_steps: int | None = None """Total number of training steps to perform. One training step is one gradient update. One of `max_train_steps` or `max_train_epochs` should be set. """ max_train_epochs: int | None = None """Total number of training epochs to perform. One epoch is one pass over the entire dataset. One of `max_train_steps` or `max_train_epochs` should be set. """ save_every_n_epochs: int | None = None """The interval (in epochs) at which to save checkpoints. One of `save_every_n_epochs` or `save_every_n_steps` should be set. """ save_every_n_steps: int | None = None """The interval (in steps) at which to save checkpoints. One of `save_every_n_epochs` or `save_every_n_steps` should be set. """ validate_every_n_epochs: int | None = None """The interval (in epochs) at which validation images will be generated. One of `validate_every_n_epochs` or `validate_every_n_steps` should be set. """ validate_every_n_steps: int | None = None """The interval (in steps) at which validation images will be generated. One of `validate_every_n_epochs` or `validate_every_n_steps` should be set. """ ================================================ FILE: src/invoke_training/config/config_base_model.py ================================================ from pydantic import BaseModel, ConfigDict class ConfigBaseModel(BaseModel): """Base model for all invoke training configuration models.""" # Configure to raise if extra fields are passed in. 
model_config = ConfigDict(extra="forbid") ================================================ FILE: src/invoke_training/config/data/__init__.py ================================================ ================================================ FILE: src/invoke_training/config/data/data_loader_config.py ================================================ from typing import Literal, Optional from invoke_training.config.config_base_model import ConfigBaseModel from invoke_training.config.data.dataset_config import ( ImageCaptionDatasetConfig, ImageDirDatasetConfig, ) class AspectRatioBucketConfig(ConfigBaseModel): target_resolution: int """The target resolution for all aspect ratios. When generating aspect ratio buckets, the resolution of each bucket is selected to have roughly `target_resolution * target_resolution` pixels (i.e. a square image with dimensions equal to `target_resolution`). """ start_dim: int """Aspect ratio bucket resolutions are generated as follows: - Iterate over 'first' dimension values from `start_dim` to `end_dim` in steps of size `divisible_by`. - Calculate the 'second' dimension to be as close as possible to the total number of pixels in `target_resolution`, while still being divisible by `divisible_by`. tip: Choosing aspect ratio buckets The aspect ratio bucket resolutions are logged at the start of training with the number of images in each bucket. Review these logs to make sure that images are being split into buckets as expected. Highly fragmented splits (i.e. many buckets with few examples in each) can 1) limit the extent to which examples can be shuffled, and 2) slow down training if there are many partial batches. """ end_dim: int """See explanation under [`start_dim`][invoke_training.config.data.data_loader_config.AspectRatioBucketConfig.start_dim]. """ divisible_by: int """See explanation under [`start_dim`][invoke_training.config.data.data_loader_config.AspectRatioBucketConfig.start_dim]. 
""" class ImageCaptionSDDataLoaderConfig(ConfigBaseModel): type: Literal["IMAGE_CAPTION_SD_DATA_LOADER"] = "IMAGE_CAPTION_SD_DATA_LOADER" dataset: ImageCaptionDatasetConfig aspect_ratio_buckets: AspectRatioBucketConfig | None = None resolution: int | tuple[int, int] = 512 """The resolution for input images. Either a scalar integer representing the square resolution height and width, or a (height, width) tuple. All of the images in the dataset will be resized to this resolution unless the `aspect_ratio_buckets` config is set. """ center_crop: bool = True """If True, input images will be center-cropped to the target resolution. If False, input images will be randomly cropped to the target resolution. """ random_flip: bool = False """Whether random flip augmentations should be applied to input images. """ caption_prefix: str | None = None """A prefix that will be prepended to all captions. If None, no prefix will be added. """ dataloader_num_workers: int = 0 """Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. """ class ImageCaptionFluxDataLoaderConfig(ConfigBaseModel): type: Literal["IMAGE_CAPTION_FLUX_DATA_LOADER"] = "IMAGE_CAPTION_FLUX_DATA_LOADER" dataset: ImageCaptionDatasetConfig aspect_ratio_buckets: AspectRatioBucketConfig | None = None resolution: int | tuple[int, int] = 512 """The resolution for input images. Either a scalar integer representing the square resolution height and width, or a (height, width) tuple. All of the images in the dataset will be resized to this resolution unless the `aspect_ratio_buckets` config is set. """ center_crop: bool = True """If True, input images will be center-cropped to the target resolution. If False, input images will be randomly cropped to the target resolution. """ random_flip: bool = False """Whether random flip augmentations should be applied to input images. """ caption_prefix: str | None = None """A prefix that will be prepended to all captions. 
If None, no prefix will be added. """ dataloader_num_workers: int = 0 """Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. """ class DreamboothSDDataLoaderConfig(ConfigBaseModel): type: Literal["DREAMBOOTH_SD_DATA_LOADER"] = "DREAMBOOTH_SD_DATA_LOADER" instance_caption: str class_caption: Optional[str] = None instance_dataset: ImageDirDatasetConfig class_dataset: Optional[ImageDirDatasetConfig] = None class_data_loss_weight: float = 1.0 """The loss weight applied to class dataset examples. Instance dataset examples have an implicit loss weight of 1.0. """ aspect_ratio_buckets: AspectRatioBucketConfig | None = None """The aspect ratio bucketing configuration. If None, aspect ratio bucketing is disabled, and all images will be resized to the same resolution. """ resolution: int | tuple[int, int] = 512 """The resolution for input images. Either a scalar integer representing the square resolution height and width, or a (height, width) tuple. All of the images in the dataset will be resized to this resolution unless the `aspect_ratio_buckets` config is set. """ center_crop: bool = True """If True, input images will be center-cropped to the target resolution. If False, input images will be randomly cropped to the target resolution. """ random_flip: bool = False """Whether random flip augmentations should be applied to input images. """ dataloader_num_workers: int = 0 """Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. """ class TextualInversionSDDataLoaderConfig(ConfigBaseModel): type: Literal["TEXTUAL_INVERSION_SD_DATA_LOADER"] = "TEXTUAL_INVERSION_SD_DATA_LOADER" dataset: ImageDirDatasetConfig | ImageCaptionDatasetConfig caption_preset: Literal["style", "object"] | None = None caption_templates: list[str] | None = None """A list of caption templates with a single template argument 'slot' in each. 
E.g.: - "a photo of a {}" - "a rendering of a {}" - "a cropped photo of the {}" """ keep_original_captions: bool = False """If `True`, then the captions generated as a result of the `caption_preset` or `caption_templates` will be used as prefixes for the original captions. If `False`, then the generated captions will replace the original captions. """ aspect_ratio_buckets: AspectRatioBucketConfig | None = None """The aspect ratio bucketing configuration. If None, aspect ratio bucketing is disabled, and all images will be resized to the same resolution. """ resolution: int | tuple[int, int] = 512 """The resolution for input images. Either a scalar integer representing the square resolution height and width, or a (height, width) tuple. All of the images in the dataset will be resized to this resolution unless the `aspect_ratio_buckets` config is set. """ center_crop: bool = True """If True, input images will be center-cropped to the target resolution. If False, input images will be randomly cropped to the target resolution. """ random_flip: bool = False """Whether random flip augmentations should be applied to input images. """ shuffle_caption_delimiter: str | None = None """If `None`, then no caption shuffling is applied. If set, then captions are split on this delimiter and shuffled. """ dataloader_num_workers: int = 0 """Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. """ ================================================ FILE: src/invoke_training/config/data/dataset_config.py ================================================ from typing import Annotated, Literal, Optional, Union from pydantic import Field from invoke_training.config.config_base_model import ConfigBaseModel class HFHubImageCaptionDatasetConfig(ConfigBaseModel): type: Literal["HF_HUB_IMAGE_CAPTION_DATASET"] = "HF_HUB_IMAGE_CAPTION_DATASET" dataset_name: str """The name of a Hugging Face dataset. 
""" dataset_config_name: Optional[str] = None """The Hugging Face dataset config name. Leave as None if there's only one config. """ hf_cache_dir: Optional[str] = None """The Hugging Face cache directory to use for dataset downloads. If None, the default value will be used (usually '~/.cache/huggingface/datasets'). """ image_column: str = "image" """The name of the dataset column that contains image paths. """ caption_column: str = "text" """The name of the dataset column that contains captions. """ class ImageCaptionJsonlDatasetConfig(ConfigBaseModel): type: Literal["IMAGE_CAPTION_JSONL_DATASET"] = "IMAGE_CAPTION_JSONL_DATASET" jsonl_path: str """The path to a JSONL file containing image paths and captions.""" image_column: str = "image" """The name of the dataset column that contains image paths. """ caption_column: str = "text" """The name of the dataset column that contains captions. """ keep_in_memory: bool = False """If `True`, load all images into memory on initialization so that they can be accessed quickly. If `False`, images are loaded from disk each time they are accessed. Setting to `True` improves performance for datasets that are small enough to be kept in memory. """ class ImageDirDatasetConfig(ConfigBaseModel): type: Literal["IMAGE_DIR_DATASET"] = "IMAGE_DIR_DATASET" dataset_dir: str """The directory to load images from.""" keep_in_memory: bool = False """If `True`, load all images into memory on initialization so that they can be accessed quickly. If `False`, images are loaded from disk each time they are accessed. Setting to `True` improves performance for datasets that are small enough to be kept in memory. """ class ImageCaptionDirDatasetConfig(ConfigBaseModel): type: Literal["IMAGE_CAPTION_DIR_DATASET"] = "IMAGE_CAPTION_DIR_DATASET" dataset_dir: str """The directory to load images from.""" keep_in_memory: bool = False """If `True`, load all images into memory on initialization so that they can be accessed quickly. 
If `False`, images are loaded from disk each time they are accessed. Setting to `True` improves performance for datasets that are small enough to be kept in memory. """ # Datasets that produce image-caption pairs. ImageCaptionDatasetConfig = Annotated[ Union[HFHubImageCaptionDatasetConfig, ImageCaptionJsonlDatasetConfig, ImageCaptionDirDatasetConfig], Field(discriminator="type"), ] ================================================ FILE: src/invoke_training/config/optimizer/__init__.py ================================================ ================================================ FILE: src/invoke_training/config/optimizer/optimizer_config.py ================================================ import typing from invoke_training.config.config_base_model import ConfigBaseModel class AdamOptimizerConfig(ConfigBaseModel): optimizer_type: typing.Literal["AdamW"] = "AdamW" learning_rate: float = 1e-4 """Initial learning rate to use (after the potential warmup period). Note that in some training pipelines this can be overriden for a specific group of params: https://pytorch.org/docs/stable/optim.html#per-parameter-options (E.g. see `text_encoder_learning_rate` and `unet_learning_rate`) """ beta1: float = 0.9 beta2: float = 0.999 weight_decay: float = 1e-2 epsilon: float = 1e-8 use_8bit: bool = False """Use an 8-bit version of the Adam optimizer. This requires the bitsandbytes library to be installed. use_8bit reduces the VRAM usage of the optimizer, but increases the risk of issues with numerical stability. """ class ProdigyOptimizerConfig(ConfigBaseModel): optimizer_type: typing.Literal["Prodigy"] = "Prodigy" learning_rate: float = 1.0 """The learning rate. For the Prodigy optimizer, the learning rate is adjusted dynamically. A value of 1.0 is recommended. Note that in some training pipelines this can be overriden for a specific group of params: https://pytorch.org/docs/stable/optim.html#per-parameter-options (E.g. 
see `text_encoder_learning_rate` and `unet_learning_rate`) """ weight_decay: float = 0.0 use_bias_correction: bool = False safeguard_warmup: bool = False ================================================ FILE: src/invoke_training/config/pipeline_config.py ================================================ from typing import Annotated, Union from pydantic import Field from invoke_training.pipelines._experimental.sd_dpo_lora.config import SdDirectPreferenceOptimizationLoraConfig from invoke_training.pipelines.flux.lora.config import FluxLoraConfig from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig from invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig from invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import ( SdxlLoraAndTextualInversionConfig, ) from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig PipelineConfig = Annotated[ Union[ FluxLoraConfig, SdLoraConfig, SdxlLoraConfig, SdTextualInversionConfig, SdxlTextualInversionConfig, SdxlLoraAndTextualInversionConfig, SdxlFinetuneConfig, SdDirectPreferenceOptimizationLoraConfig, ], Field(discriminator="type"), ] ================================================ FILE: src/invoke_training/model_merge/__init__.py ================================================ ================================================ FILE: src/invoke_training/model_merge/extract_lora.py ================================================ import torch import tqdm from peft.peft_model import PeftModel # All original base model weights in a PeftModel have this prefix and suffix. PEFT_BASE_LAYER_PREFIX = "base_model.model." 
PEFT_BASE_LAYER_SUFFIX = ".base_layer.weight" def get_patched_base_weights_from_peft_model(peft_model: PeftModel) -> dict[str, torch.Tensor]: """Get a state_dict containing the base model weights *thath are patched* in the provided PeftModel. I.e. only return base model weights that have associated LoRa layers, but don't return the LoRA layers. """ state_dict = peft_model.state_dict() out_state_dict: dict[str, torch.Tensor] = {} for weight_name in state_dict: # Weights that end with ".base_layer.weight" are the original weights for LoRA layers. if weight_name.endswith(PEFT_BASE_LAYER_SUFFIX): # Extract the base module name. module_name = weight_name[: -len(PEFT_BASE_LAYER_SUFFIX)] assert module_name.startswith(PEFT_BASE_LAYER_PREFIX) module_name = module_name[len(PEFT_BASE_LAYER_PREFIX) :] out_state_dict[module_name] = state_dict[weight_name] return out_state_dict def get_state_dict_diff( state_dict_1: dict[str, torch.Tensor], state_dict_2: dict[str, torch.Tensor] ) -> dict[str, torch.Tensor]: """Return the difference between two state_dicts: state_dict_1 - state_dict_2.""" return {key: state_dict_1[key] - state_dict_2[key] for key in state_dict_1} @torch.no_grad() def extract_lora_from_diffs( diffs: dict[str, torch.Tensor], rank: int, clamp_quantile: float, out_dtype: torch.dtype ) -> dict[str, tuple[torch.Tensor, torch.Tensor]]: lora_weights = {} for lora_name, mat in tqdm.tqdm(list(diffs.items())): # Use full precision for the intermediate calculations. mat = mat.to(torch.float32) is_conv2d = False if len(mat.shape) == 4: # Conv2D is_conv2d = True out_dim, in_dim, kernel_h, kernel_w = mat.shape # Reshape to (out_dim, in_dim * kernel_h * kernel_w). mat = mat.flatten(start_dim=1) elif len(mat.shape) == 2: # Linear out_dim, in_dim = mat.shape else: raise ValueError(f"Unexpected weight shape: {mat.shape}") # LoRA rank cannot exceed the original dimensions. 
assert rank < in_dim assert rank < out_dim u: torch.Tensor s: torch.Tensor v_h: torch.Tensor u, s, v_h = torch.linalg.svd(mat) # Apply the Eckart-Young-Mirsky theorem. # https://en.wikipedia.org/wiki/Low-rank_approximation#Proof_of_Eckart%E2%80%93Young%E2%80%93Mirsky_theorem_(for_Frobenius_norm) u = u[:, :rank] s = s[:rank] u = u @ torch.diag(s) v_h = v_h[:rank, :] # At this point, u is the lora_up (a.k.a. lora_B) weight, and v_h is the lora_down (a.k.a. lora_A) weight. # The reason we don't use more appropriate variable names is to keep memory usage low - we want the old tensors # to get cleaned up after each operation. # Clamp the outliers. dist = torch.cat([u.flatten(), v_h.flatten()]) hi_val = torch.quantile(dist, clamp_quantile) low_val = -hi_val u = u.clamp(low_val, hi_val) v_h = v_h.clamp(low_val, hi_val) if is_conv2d: u = u.reshape(out_dim, rank, 1, 1) v_h = v_h.reshape(rank, in_dim, kernel_h, kernel_w) u = u.to(dtype=out_dtype).contiguous() v_h = v_h.to(dtype=out_dtype).contiguous() lora_weights[lora_name] = (u, v_h) return lora_weights ================================================ FILE: src/invoke_training/model_merge/merge_models.py ================================================ from typing import Literal import torch import tqdm from invoke_training.model_merge.utils.normalize_weights import normalize_weights @torch.no_grad() def merge_models( state_dicts: list[dict[str, torch.Tensor]], weights: list[float], merge_method: Literal["LERP", "SLERP"] = "LERP" ): """Merge multiple models into a single model. Args: state_dicts (list[dict[str, torch.Tensor]]): The state dicts to merge. weights (list[float]): The weights for each state dict. The weights will be normalized to sum to 1. merge_method (Literal["LERP", "SLERP"]): Merge method to use. Options: - "LERP": Linear interpolation a.k.a. weighted sum. - "SLERP": Spherical linear interpolation. 
""" if len(state_dicts) < 2: raise ValueError("Must provide >=2 models to merge.") if len(state_dicts) != len(weights): raise ValueError("Must provide a weight for each model.") if merge_method == "LERP": merge_fn = lerp elif merge_method == "SLERP": merge_fn = slerp else: raise ValueError(f"Unknown merge method: {merge_method}") normalized_weights = normalize_weights(weights) out_state_dict: dict[str, torch.Tensor] = state_dicts[0].copy() out_state_dict_weight = normalized_weights[0] for state_dict, normalized_weight in zip(state_dicts[1:], normalized_weights[1:], strict=True): if state_dict.keys() != out_state_dict.keys(): raise ValueError("State dicts must have the same keys.") cur_pair_weights = normalize_weights([out_state_dict_weight, normalized_weight]) for key in tqdm.tqdm(out_state_dict.keys()): out_state_dict[key] = merge_fn(out_state_dict[key], state_dict[key], cur_pair_weights[0]) # Update the weight of out_state_dict to be the sum of all state dicts merged so far. out_state_dict_weight += normalized_weight return out_state_dict def lerp(a: torch.Tensor, b: torch.Tensor, weight_a: float) -> torch.Tensor: """Linear interpolation.""" return torch.lerp(a, b, (1.0 - weight_a)) def slerp(a: torch.Tensor, b: torch.Tensor, weight_a: float, dot_product_thres=0.9995, epsilon=1e-10): """Spherical linear interpolation.""" # TODO(ryand): For multi-dimensional matrices, it might be better to apply slerp on a subset of the dimensions # (e.g. per-row), rather than treating the entire matrix as a single flattened vector. # Normalize the vectors. a_norm = torch.linalg.norm(a) b_norm = torch.linalg.norm(b) a_normalized = a / a_norm b_normalized = b / b_norm if a_norm < epsilon or b_norm < epsilon: # If either vector is very small, fallback to lerp to avoid weird effects. # TODO(ryand): Is fallback here necessary? return lerp(a, b, weight_a) # Dot product of the normalized vectors. # We are effectively treating multi-dimensional tensors as flattened vectors. 
dot_prod = torch.sum(a_normalized * b_normalized) # If the absolute value of the dot product is almost 1, the vectors are ~colinear, so use lerp. if torch.abs(dot_prod) > dot_product_thres: return lerp(a, b, weight_a) # Calculate initial angle between the vectors. theta_0 = torch.acos(dot_prod) # Angle at timestep t. t = 1.0 - weight_a theta_t = theta_0 * t sin_theta_0 = torch.sin(theta_0) sin_theta_t = torch.sin(theta_t) s0 = torch.sin(theta_0 - theta_t) / sin_theta_0 s1 = sin_theta_t / sin_theta_0 result = s0 * a + s1 * b return result ================================================ FILE: src/invoke_training/model_merge/merge_tasks_to_base.py ================================================ from typing import Literal import torch import tqdm from peft.utils.merge_utils import dare_linear, dare_ties, ties @torch.no_grad() def merge_tasks_to_base_model( base_state_dict: dict[str, torch.Tensor], task_state_dicts: list[dict[str, torch.Tensor]], task_weights: list[float], density: float = 0.2, merge_method: Literal["TIES", "DARE_LINEAR", "DARE_TIES"] = "TIES", ) -> torch.Tensor: """Merge a base model with one or more task-specific models. Args: base_state_dict (dict[str, torch.Tensor]): The base state dict to merge with. task_state_dicts (list[dict[str, torch.Tensor]]): A list of task-specific state dicts to merge into the base state dict. task_weights (list[float]): Weights for each task state dict. Weights of 1.0 for all task_state_dicts are recommended as a starting point (e.g. [1.0, 1.0, 1.0]). The weights can be adjusted from there (e.g. [1.0, 1.3, 1.0]). The weights are multipliers applied to the diff between each task_state_dict and the base model. density (float, optional): The fraction of values to preserve in the prune/trim step of DARE/TIES methods. Should be in the range [0, 1]. merge_method (Literal["TIES", "DARE_LINEAR", "DARE_TIES"], optional): The method to use for merging. 
Options: - "TIES": Use the TIES method (https://arxiv.org/pdf/2306.01708) - "DARE_LINEAR": Use the DARE method with linear interpolation (https://arxiv.org/pdf/2311.03099) - "DARE_TIES": Use the DARE method for pruning, and the TIES method for merging. """ if len(task_state_dicts) != len(task_weights): raise ValueError("Must provide a weight for each model.") task_weights = torch.tensor(task_weights) # Choose the merging method. if merge_method == "TIES": merge_fn = ties elif merge_method == "DARE_LINEAR": merge_fn = dare_linear elif merge_method == "DARE_TIES": merge_fn = dare_ties else: raise ValueError(f"Unknown merge method: {merge_method}") out_state_dict: dict[str, torch.Tensor] = {} for key in tqdm.tqdm(base_state_dict.keys()): base_tensor = base_state_dict[key] orig_dtype = base_tensor.dtype # Calculate the diff between each task tensor and the base tensor. task_diff_tensors = [state_dict[key] - base_tensor for state_dict in task_state_dicts] merged_diff_tensor = merge_fn( task_tensors=task_diff_tensors, weights=task_weights, density=density, ) # Some of the merge_fn implementations may return a tensor with a different dtype than the original tensors. # We cast the merged_diff_tensor back to the original dtype here. 
out_state_dict[key] = (base_tensor + merged_diff_tensor).to(dtype=orig_dtype) return out_state_dict ================================================ FILE: src/invoke_training/model_merge/scripts/extract_lora_from_model_diff.py ================================================ # This script is based on # https://raw.githubusercontent.com/kohya-ss/sd-scripts/bfb352bc433326a77aca3124248331eb60c49e8c/networks/extract_lora_from_models.py # That script was originally based on https://github.com/cloneofsimo/lora/blob/develop/lora_diffusion/cli_svd.py import argparse import logging import sys from dataclasses import dataclass from pathlib import Path from typing import Literal import peft import torch from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, UNet2DConditionModel from transformers import CLIPTextModel, CLIPTextModelWithProjection from invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import ( TEXT_ENCODER_TARGET_MODULES, UNET_TARGET_MODULES, save_sdxl_kohya_checkpoint, ) from invoke_training._shared.stable_diffusion.model_loading_utils import ( PipelineVersionEnum, from_pretrained_with_variant_fallback, load_pipeline, ) from invoke_training.model_merge.extract_lora import ( PEFT_BASE_LAYER_PREFIX, extract_lora_from_diffs, get_patched_base_weights_from_peft_model, get_state_dict_diff, ) from invoke_training.model_merge.utils.parse_model_arg import parse_model_arg @dataclass class StableDiffusionModel: """A helper class to store the submodels of a SD model that we are interested in for LoRA extraction.""" unet: UNet2DConditionModel | None = None text_encoder: CLIPTextModel | None = None text_encoder_2: CLIPTextModelWithProjection | None = None def all_none(self) -> bool: return self.unet is None and self.text_encoder is None and self.text_encoder_2 is None def load_model( logger: logging.Logger, model_name_or_path: str, model_type: 
PipelineVersionEnum, variant: str | None, dtype: torch.dtype, ) -> StableDiffusionModel: sd_model = StableDiffusionModel() model_path = Path(model_name_or_path) if model_path.is_dir(): # model_path is a directory, so we'll try to load the submodels of interest from its subdirectories. logger.info(f"'{model_name_or_path}' is a directory. Attempting to load submodels.") for submodel_name, submodel_class in [ ("unet", UNet2DConditionModel), ("text_encoder", CLIPTextModel), ("text_encoder_2", CLIPTextModelWithProjection), ]: submodel_path: Path = model_path / submodel_name if submodel_path.exists(): logger.info(f"Loading '{submodel_name}' from '{submodel_path}'.") submodel = from_pretrained_with_variant_fallback( logger=logger, model_class=submodel_class, model_name_or_path=submodel_path, torch_dtype=dtype, variant=variant, local_files_only=True, ) setattr(sd_model, submodel_name, submodel) else: logger.info(f"'{submodel_name}' not found in '{model_name_or_path}'. Skipping.") continue else: # model_name_or_path is not a directory, so it is either: # 1) a single checkpoint file # 2) a HF model name # Both can be loaded by calling load_pipeline. logger.info(f"'{model_name_or_path}' is a single checkpoint file. 
Attempting to load.") pipeline = load_pipeline( logger=logger, model_name_or_path=model_name_or_path, pipeline_version=model_type, torch_dtype=dtype, variant=variant, ) if isinstance(pipeline, StableDiffusionPipeline): sd_model.unet = pipeline.unet sd_model.text_encoder = pipeline.text_encoder elif isinstance(pipeline, StableDiffusionXLPipeline): sd_model.unet = pipeline.unet sd_model.text_encoder = pipeline.text_encoder sd_model.text_encoder_2 = pipeline.text_encoder_2 else: raise RuntimeError(f"Unexpected pipeline type: {type(pipeline)}.") if sd_model.all_none(): raise RuntimeError(f"Failed to load any submodels from '{model_name_or_path}'.") return sd_model def str_to_device(device_str: Literal["cuda", "cpu"]) -> torch.device: if device_str == "cuda": return torch.device("cuda") elif device_str == "cpu": return torch.device("cpu") else: raise ValueError(f"Unexpected device: {device_str}") def state_dict_to_device(state_dict: dict[str, torch.Tensor], device: torch.device) -> dict[str, torch.Tensor]: return {k: v.to(device=device) for k, v in state_dict.items()} def extract_lora_from_submodel( logger: logging.Logger, model_orig: torch.nn.Module, model_tuned: torch.nn.Module, device: torch.device, out_dtype: torch.dtype, lora_target_modules: list[str], lora_rank: int, clamp_quantile: float = 0.99, ) -> peft.PeftModel: """Extract LoRA weights from the diff between model_orig and model_tuned. Returns a new model_orig, wrapped in a PeftModel, with the LoRA weights applied. """ # Apply LoRA to the UNet. # The only reason we do this is to get the module names for the weights that we'll extract. We don't actually use # the LoRA weights initialized here. unet_lora_config = peft.LoraConfig( r=lora_rank, # We set the alpha to the rank, because we don't want any scaling to be applied to the LoRA weights that we # extract. 
lora_alpha=lora_rank, target_modules=lora_target_modules, ) model_tuned = peft.get_peft_model(model_tuned, unet_lora_config) model_orig = peft.get_peft_model(model_orig, unet_lora_config) base_weights_tuned = get_patched_base_weights_from_peft_model(model_tuned) base_weights_orig = get_patched_base_weights_from_peft_model(model_orig) diffs = get_state_dict_diff(base_weights_tuned, base_weights_orig) # Clear tuned model to save memory. # TODO(ryand): We also need to clear the state_dicts. Move the diff extraction to a separate function so that memory # cleanup is handled by scoping. del model_tuned # Apply SVD (Singluar Value Decomposition) to the diffs. # We just use the device for this calculation, since it's slow, then we move the results back to the CPU. logger.info("Calculating LoRA weights with SVD.") diffs = state_dict_to_device(diffs, device) # TODO(ryand): Should we skip if the diffs are all zeros? This would happen if two models are identical. This could # happen if some submodels differ while others don't. lora_weights = extract_lora_from_diffs( diffs=diffs, rank=lora_rank, clamp_quantile=clamp_quantile, out_dtype=out_dtype ) # Prepare state dict for LoRA. lora_state_dict = {} for module_name, (lora_up, lora_down) in lora_weights.items(): lora_state_dict[PEFT_BASE_LAYER_PREFIX + module_name + ".lora_A.default.weight"] = lora_down lora_state_dict[PEFT_BASE_LAYER_PREFIX + module_name + ".lora_B.default.weight"] = lora_up # The alpha value is set once globally in the PEFT model, so no need to set it for each module. # lora_state_dict[peft_base_layer_suffix + module_name + ".alpha"] = torch.tensor(down_weight.size()[0]) lora_state_dict = state_dict_to_device(lora_state_dict, torch.device("cpu")) # Load the state_dict into the LoRA model. 
model_orig.load_state_dict(lora_state_dict, strict=False, assign=True) return model_orig @torch.no_grad() def extract_lora( logger: logging.Logger, model_type: PipelineVersionEnum, orig_model_name_or_path: str, orig_model_variant: str | None, tuned_model_name_or_path: str, tuned_model_variant: str | None, save_to: str, load_precision: Literal["float32", "float16", "bfloat16"], save_precision: Literal["float32", "float16", "bfloat16"], device: Literal["cuda", "cpu"], lora_rank: int, clamp_quantile=0.99, ): load_dtype = get_dtype_from_str(load_precision) save_dtype = get_dtype_from_str(save_precision) device = str_to_device(device) orig_model = load_model( logger=logger, model_name_or_path=orig_model_name_or_path, model_type=model_type, dtype=load_dtype, variant=orig_model_variant, ) tuned_model = load_model( logger=logger, model_name_or_path=tuned_model_name_or_path, model_type=model_type, dtype=load_dtype, variant=tuned_model_variant, ) lora_models: dict[str, peft.PeftModel] = {} for submodel_name, submodel_orig, submodel_tuned, lora_target_modules in [ ("unet", orig_model.unet, tuned_model.unet, UNET_TARGET_MODULES), ("text_encoder", orig_model.text_encoder, tuned_model.text_encoder, TEXT_ENCODER_TARGET_MODULES), ("text_encoder_2", orig_model.text_encoder_2, tuned_model.text_encoder_2, TEXT_ENCODER_TARGET_MODULES), ]: if submodel_orig is not None and submodel_tuned is not None: logger.info(f"Extracting LoRA weights for '{submodel_name}'.") lora_models[submodel_name] = extract_lora_from_submodel( logger=logger, model_orig=submodel_orig, model_tuned=submodel_tuned, device=device, out_dtype=save_dtype, lora_target_modules=lora_target_modules, lora_rank=lora_rank, clamp_quantile=clamp_quantile, ) else: logger.info(f"Skipping '{submodel_name}'.") # Save the LoRA weights. 
save_to_path = Path(save_to) assert save_to_path.suffix == ".safetensors" if save_to_path.exists(): raise FileExistsError(f"Destination file already exists: '{save_to}'.") save_to_path.parent.mkdir(parents=True, exist_ok=True) save_sdxl_kohya_checkpoint( save_to_path, unet=lora_models.get("unet", None), text_encoder_1=lora_models.get("text_encoder", None), text_encoder_2=lora_models.get("text_encoder_2", None), ) logger.info(f"Saved LoRA weights to: {save_to_path}") def main(): parser = argparse.ArgumentParser() parser.add_argument( "--model-type", type=str, choices=["SD", "SDXL"], help="The type of the models to merge ['SD', 'SDXL'].", ) parser.add_argument( "--model-orig", type=str, required=True, help="Path or HF Hub name of the original model. The model must be in one of the following formats: " "1) a single checkpoint file (e.g. '.safetensors') containing all submodels, " "2) a model in diffusers format containing all submodels, " "or 3) a model in diffusers format containing a subset of the submodels (e.g. only a UNet)." "An HF variant can optionally be appended to the model name after a double-colon delimiter ('::')." "E.g. '--model-orig runwayml/stable-diffusion-v1-5::fp16'", ) parser.add_argument( "--model-tuned", type=str, required=True, help="Path or HF Hub name of the tuned model. The model must be in one of the following formats: " "1) a single checkpoint file (e.g. '.safetensors') containing all submodels, " "2) a model in diffusers format containing all submodels, " "or 3) a model in diffusers format containing a subset of the submodels (e.g. only a UNet)." "An HF variant can optionally be appended to the model name after a double-colon delimiter ('::')." "E.g. 
'--model-orig runwayml/stable-diffusion-v1-5::fp16'", ) parser.add_argument( "--save-to", type=str, required=True, help="Destination file path (must have a .safetensors extension).", ) parser.add_argument( "--load-precision", type=str, default="bfloat16", choices=["float32", "float16", "bfloat16"], help="Model load precision.", ) parser.add_argument( "--save-precision", type=str, default="float16", choices=["float32", "float16", "bfloat16"], help="Model save precision.", ) parser.add_argument("--lora-rank", type=int, default=4, help="LoRA rank dimension.") parser.add_argument("--clamp-quantile", type=float, default=0.99, help="Quantile clamping value. (0-1)") parser.add_argument( "--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device to use. (cuda or cpu)" ) args = parser.parse_args() logging.basicConfig(stream=sys.stdout, level=logging.INFO) logger = logging.getLogger() orig_model_name_or_path, orig_model_variant = parse_model_arg(args.model_orig) tuned_model_name_or_path, tuned_model_variant = parse_model_arg(args.model_tuned) extract_lora( logger=logger, model_type=PipelineVersionEnum(args.model_type), orig_model_name_or_path=orig_model_name_or_path, orig_model_variant=orig_model_variant, tuned_model_name_or_path=tuned_model_name_or_path, tuned_model_variant=tuned_model_variant, save_to=args.save_to, load_precision=args.load_precision, save_precision=args.save_precision, device=args.device, lora_rank=args.lora_rank, clamp_quantile=args.clamp_quantile, ) if __name__ == "__main__": main() ================================================ FILE: src/invoke_training/model_merge/scripts/merge_lora_into_model.py ================================================ import argparse # noqa: I001 import logging from pathlib import Path import torch from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline # fmt: off # HACK(ryand): Import order matters, because invokeai contains circular imports. 
from invokeai.backend.model_manager.taxonomy import BaseModelType from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.lora_conversions.sd_lora_conversion_utils import \ lora_model_from_sd_state_dict from invokeai.backend.util.original_weights_storage import \ OriginalWeightsStorage from safetensors.torch import load_file # fmt: on from invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline from invoke_training.model_merge.utils.parse_model_arg import parse_model_arg def to_invokeai_base_model_type(model_type: PipelineVersionEnum): if model_type == PipelineVersionEnum.SD: return BaseModelType.StableDiffusion1 elif model_type == PipelineVersionEnum.SDXL: return BaseModelType.StableDiffusionXL else: raise ValueError(f"Unexpected model_type: {model_type}") @torch.no_grad() def merge_lora_into_sd_model( logger: logging.Logger, model_type: PipelineVersionEnum, base_model: str, base_model_variant: str | None, lora_models: list[tuple[str, float]], output: str, save_dtype: str, ): pipeline: StableDiffusionXLPipeline | StableDiffusionPipeline = load_pipeline( logger=logger, model_name_or_path=base_model, pipeline_version=model_type, variant=base_model_variant ) save_dtype = get_dtype_from_str(save_dtype) logger.info(f"Loaded base model: '{base_model}'.") pipeline.to(save_dtype) models: list[torch.nn.Module] = [] lora_prefixes: list[str] = [] if isinstance(pipeline, StableDiffusionPipeline): models = [pipeline.unet, pipeline.text_encoder] lora_prefixes = ["lora_unet_", "lora_te_"] elif isinstance(pipeline, StableDiffusionXLPipeline): models = [pipeline.unet, pipeline.text_encoder, pipeline.text_encoder_2] lora_prefixes = ["lora_unet_", "lora_te1_", "lora_te2_"] else: raise ValueError(f"Unexpected pipeline type: {type(pipeline)}") # Although we are not unpatching, the patcher might require this. 
Initialize empty. original_weights = OriginalWeightsStorage() for lora_model_path, lora_model_weight in lora_models: # Load state dict from file lora_path = Path(lora_model_path) if lora_path.suffix == ".safetensors": state_dict = load_file(lora_path.absolute().as_posix(), device="cpu") else: # Assuming .ckpt, .pt, .bin etc. are torch checkpoints state_dict = torch.load(lora_path, map_location="cpu") # Convert state dict to ModelPatchRaw lora_model = lora_model_from_sd_state_dict(state_dict=state_dict) # Apply the patch using LayerPatcher for model, lora_prefix in zip(models, lora_prefixes, strict=True): LayerPatcher.apply_smart_model_patch( model=model, prefix=lora_prefix, patch=lora_model, patch_weight=lora_model_weight, original_weights=original_weights, # Pass storage, even if unused for merging original_modules={}, # Pass empty dict, not needed for direct patching/merging dtype=model.dtype, # Use the model's dtype # Force direct patching since we are merging into the main weights force_direct_patching=True, force_sidecar_patching=False, ) logger.info(f"Applied LoRA model '{lora_model_path}' with weight {lora_model_weight}.") output_path = Path(output) output_path.mkdir(parents=True) # TODO(ryand): Should we keep the base model variant? This is clearly a flawed assumption. 
pipeline.save_pretrained(output_path, variant=base_model_variant) logger.info(f"Saved merged model to '{output_path}'.") def parse_lora_model_arg(lora_model_arg: str) -> tuple[str, float]: """Parse a --lora-model argument into a tuple of the model path and weight.""" parts = lora_model_arg.split("::") if len(parts) == 1: return parts[0], 1.0 elif len(parts) == 2: return parts[0], float(parts[1]) else: raise ValueError(f"Unexpected format for --lora-model arg: '{lora_model_arg}'.") def main(): parser = argparse.ArgumentParser() parser.add_argument( "--model-type", type=str, choices=["SD", "SDXL"], help="The type of the models to merge ['SD', 'SDXL'].", ) parser.add_argument( "--base-model", type=str, help="The base model to merge LoRAs into. The model can be either 1) an HF hub name, 2) a path to a local " "diffusers model directory, or 3) a path to a single checkpoint file. An HF variant can optionally be appended " "to the model name after a double-colon delimiter ('::')." "E.g. '--base-model runwayml/stable-diffusion-v1-5::fp16'", required=True, ) parser.add_argument( "--lora-models", type=str, nargs="+", help="The path(s) to one or more LoRA models to merge into the base model. Model weights can be appended to " "the path, separated by a double colon ('::'). The weight is optional and defaults to 1.0. E.g. 
" "'--lora-models path/to/lora_model_1.safetensors::0.5 path/to/lora_model_2.safetensors'.", required=True, ) parser.add_argument( "--output", type=str, help="The path to an output directory where the merged model will be saved (in diffusers format).", ) parser.add_argument( "--save-dtype", type=str, default="float16", choices=["float32", "float16", "bfloat16"], help="The dtype to save the model as.", ) args = parser.parse_args() logging.basicConfig(level=logging.INFO) logger = logging.getLogger() base_model, base_model_variant = parse_model_arg(args.base_model) lora_models = [parse_lora_model_arg(arg) for arg in args.lora_models] # Log the parsed arguments logger.info(f"Model type: {args.model_type}") logger.info(f"Base model: {base_model}") logger.info(f"Base model variant: {base_model_variant}") logger.info(f"Output directory: {args.output}") logger.info(f"Save dtype: {args.save_dtype}") lora_models_str = " - " + "\n - ".join([f"{model} ({weight})" for model, weight in lora_models]) logger.info(f"LoRA models:\n{lora_models_str}") merge_lora_into_sd_model( logger=logger, model_type=PipelineVersionEnum(args.model_type), base_model=base_model, base_model_variant=base_model_variant, lora_models=lora_models, output=args.output, save_dtype=args.save_dtype, ) if __name__ == "__main__": main() ================================================ FILE: src/invoke_training/model_merge/scripts/merge_models.py ================================================ import argparse import logging from dataclasses import dataclass from pathlib import Path import torch from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline from invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline from invoke_training.model_merge.merge_models import merge_models from invoke_training.model_merge.utils.parse_model_arg import parse_model_arg @dataclass class 
MergeModel: model_name_or_path: str variant: str | None weight: float def run_merge_models( logger: logging.Logger, model_type: PipelineVersionEnum, models: list[MergeModel], method: str, out_dir: str, dtype: torch.dtype, ): # Create the output directory if it doesn't exist. out_dir_path = Path(out_dir) out_dir_path.mkdir(parents=True, exist_ok=False) # Load the models. loaded_models: list[StableDiffusionPipeline] | list[StableDiffusionXLPipeline] = [] for model in models: loaded_model = load_pipeline( logger=logger, model_name_or_path=model.model_name_or_path, pipeline_version=model_type, torch_dtype=dtype, variant=model.variant, ) loaded_models.append(loaded_model) # Select the submodels to merge. if model_type == PipelineVersionEnum.SDXL: submodel_names = ["unet", "text_encoder", "text_encoder_2"] elif model_type == PipelineVersionEnum.SD: submodel_names = ["unet", "text_encoder"] else: raise ValueError(f"Unexpected model type: {model_type}") # Merge the models. weights = [model.weight for model in models] for submodel_name in submodel_names: submodels: list[torch.nn.Module] = [getattr(loaded_model, submodel_name) for loaded_model in loaded_models] submodel_state_dicts: list[dict[str, torch.Tensor]] = [submodel.state_dict() for submodel in submodels] logger.info(f"Merging {submodel_name} state_dicts...") merged_state_dict = merge_models(state_dicts=submodel_state_dicts, weights=weights, merge_method=method) # Merge the merged_state_dict back into the first pipeline to keep memory utilization low. submodels[0].load_state_dict(merged_state_dict, assign=True) logger.info(f"Merged {submodel_name} state_dicts.") # Save the merged model. 
logger.info("Saving result...") loaded_models[0].save_pretrained(out_dir_path) logger.info(f"Saved merged model to '{out_dir_path}'.") def parse_model_args(models: list[str], weights: list[str]) -> list[MergeModel]: """Parse a list of --models arguments and --weights arguments into a list of MergeModels.""" merge_model_list: list[MergeModel] = [] for model, weight in zip(models, weights, strict=True): parsed_model, parsed_variant = parse_model_arg(model) merge_model_list.append( MergeModel(model_name_or_path=parsed_model, variant=parsed_variant, weight=float(weight)) ) return merge_model_list def main(): parser = argparse.ArgumentParser() # TODO(ryand): Auto-detect the model-type. parser.add_argument( "--model-type", type=str, choices=["SD", "SDXL"], help="The type of the models to merge ['SD', 'SDXL'].", ) parser.add_argument( "--models", nargs="+", type=str, required=True, help="Two or more models to merge. Each model can be either 1) an HF hub name, 2) a path to a local diffusers " "model directory, or 3) a path to a single checkpoint file. An HF variant can optionally be appended to the " "model name after a double-colon delimiter ('::')." "E.g. '--models runwayml/stable-diffusion-v1-5::fp16 path/to/local/model.safetensors'", ) parser.add_argument( "--weights", nargs="+", type=float, required=True, help="The weights for each model. The weights will be normalized to sum to 1. " "For example, to merge weights with equal weights: '--weights 1.0 1.0'. " "To weight the first model more heavily: '--weights 0.75 0.25'.", ) parser.add_argument( "--method", type=str, default="LERP", choices=["LERP", "SLERP"], help="The merge method to use. 
Options: 'LERP' (linear interpolation) or 'SLERP' (spherical linear " "interpolation).", ) parser.add_argument( "--out-dir", type=str, required=True, help="The output directory where the merged model will be written (in diffusers format).", ) parser.add_argument( "--dtype", help="The torch dtype that will be used for all calculations and for the output model.", type=str, default="float16", choices=["float32", "float16", "bfloat16"], ) args = parser.parse_args() logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) merge_model_list = parse_model_args(args.models, args.weights) run_merge_models( logger=logger, model_type=PipelineVersionEnum(args.model_type), models=merge_model_list, method=args.method, out_dir=args.out_dir, dtype=get_dtype_from_str(args.dtype), ) if __name__ == "__main__": main() ================================================ FILE: src/invoke_training/model_merge/scripts/merge_task_models_to_base_model.py ================================================ import argparse import logging from pathlib import Path import torch from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline from invoke_training._shared.accelerator.accelerator_utils import get_dtype_from_str from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum, load_pipeline from invoke_training.model_merge.merge_tasks_to_base import merge_tasks_to_base_model from invoke_training.model_merge.scripts.merge_models import MergeModel, parse_model_args def run_merge_models( logger: logging.Logger, model_type: PipelineVersionEnum, base_model: MergeModel, task_models: list[MergeModel], method: str, density: float, out_dir: str, dtype: torch.dtype, ): # Create the output directory if it doesn't exist. out_dir_path = Path(out_dir) out_dir_path.mkdir(parents=True, exist_ok=False) # Load the base model. 
loaded_base_model = load_pipeline( logger=logger, model_name_or_path=base_model.model_name_or_path, pipeline_version=model_type, torch_dtype=dtype, variant=base_model.variant, ) # Load the task models. loaded_task_models: list[StableDiffusionPipeline] | list[StableDiffusionXLPipeline] = [] for task_model in task_models: loaded_task_model = load_pipeline( logger=logger, model_name_or_path=task_model.model_name_or_path, pipeline_version=model_type, torch_dtype=dtype, variant=task_model.variant, ) loaded_task_models.append(loaded_task_model) # Select the submodels to merge. if model_type == PipelineVersionEnum.SDXL: submodel_names = ["unet", "text_encoder", "text_encoder_2"] elif model_type == PipelineVersionEnum.SD: submodel_names = ["unet", "text_encoder"] else: raise ValueError(f"Unexpected model type: {model_type}") # Merge the models. task_model_weights = [task_model.weight for task_model in task_models] for submodel_name in submodel_names: base_submodel: torch.nn.Module = getattr(loaded_base_model, submodel_name) base_submodel_state_dict = base_submodel.state_dict() task_submodels: list[torch.nn.Module] = [ getattr(loaded_task_model, submodel_name) for loaded_task_model in loaded_task_models ] task_submodel_state_dict = [submodel.state_dict() for submodel in task_submodels] logger.info(f"Merging {submodel_name} state_dicts...") merged_state_dict = merge_tasks_to_base_model( base_state_dict=base_submodel_state_dict, task_state_dicts=task_submodel_state_dict, task_weights=task_model_weights, density=density, merge_method=method, ) # Merge the merged_state_dict back into the base model pipeline to keep memory utilization low. base_submodel.load_state_dict(merged_state_dict, assign=True) logger.info(f"Merged {submodel_name} state_dicts.") # Delete the task models to free up memory. # At the time of the writing, the save_pretrained(...) function below caused a large spike in memory usage. We free # the task models to increase its likelihood of success. 
del loaded_task_models # Save the merged model. logger.info("Saving result...") loaded_base_model.save_pretrained(out_dir_path) logger.info(f"Saved merged model to '{out_dir_path}'.") def main(): parser = argparse.ArgumentParser() # TODO(ryand): Auto-detect the base-model-type. parser.add_argument( "--model-type", type=str, choices=["SD", "SDXL"], help="The type of the models to merge ['SD', 'SDXL'].", ) parser.add_argument( "--base-model", type=str, help="The base model to merge task-specific models into. Can be either 1) an HF hub name, 2) a path to a local " "diffusers model directory, or 3) a path to a single checkpoint file. An HF variant can optionally be appended " "to the model name after a double-colon delimiter ('::')." "E.g. '--base-model runwayml/stable-diffusion-v1-5::fp16'.", ) parser.add_argument( "--task-models", nargs="+", type=str, required=True, help="One or more task-specific models to merge into the base model. Each model can be either 1) an HF hub " "name, 2) a path to a local diffusers model directory, or 3) a path to a single checkpoint file. An HF variant " "can optionally be appended to the model name after a double-colon delimiter ('::')." "E.g. '--task-models runwayml/stable-diffusion-v1-5::fp16 path/to/local/model.safetensors'", ) parser.add_argument( "--task-weights", nargs="+", type=float, required=True, help="The weights for each task model. The weights are multipliers applied to the diff between each task model " "and the base model. As a starting point, it is recommended to use a weight of 1.0 for all task models, e.g. " "'--task-weights 1.0 1.0'. The weights can then be tuned from there, e.g. '--task-weights 1.0 1.3'.", ) parser.add_argument( "--method", type=str, default="TIES", choices=["TIES", "DARE_LINEAR", "DARE_TIES"], help="The merge method to use. 
Options: ['TIES', 'DARE_LINEAR', 'DARE_TIES'].", ) parser.add_argument( "--density", type=float, default=0.2, help="The fraction of values to preserve in the prune/trim step of DARE/TIES methods. Should be in the range " "[0, 1].", ) parser.add_argument( "--out-dir", type=str, required=True, help="The output directory where the merged model will be written (in diffusers format).", ) parser.add_argument( "--dtype", help="The torch dtype that will be used for all calculations and for the output model.", type=str, default="float16", choices=["float32", "float16", "bfloat16"], ) args = parser.parse_args() logging.basicConfig(level=logging.INFO) logger = logging.getLogger() base_model = parse_model_args([args.base_model], [1.0])[0] task_models = parse_model_args(args.task_models, args.task_weights) run_merge_models( logger=logger, model_type=PipelineVersionEnum(args.model_type), base_model=base_model, task_models=task_models, method=args.method, density=args.density, out_dir=args.out_dir, dtype=get_dtype_from_str(args.dtype), ) if __name__ == "__main__": main() ================================================ FILE: src/invoke_training/model_merge/utils/__init__.py ================================================ ================================================ FILE: src/invoke_training/model_merge/utils/normalize_weights.py ================================================ def normalize_weights(weights: list[float]) -> list[float]: total = sum(weights) return [weight / total for weight in weights] ================================================ FILE: src/invoke_training/model_merge/utils/parse_model_arg.py ================================================ def parse_model_arg(model: str, delimiter: str = "::") -> tuple[str, str | None]: """Parse a model argument into a model and a variant.""" parts = model.split(delimiter) if len(parts) == 1: return parts[0], None elif len(parts) == 2: return parts[0], parts[1] else: raise ValueError(f"Unexpected format for --models arg: 
'{model}'.") ================================================ FILE: src/invoke_training/pipelines/__init__.py ================================================ ================================================ FILE: src/invoke_training/pipelines/_experimental/sd_dpo_lora/config.py ================================================ from typing import Annotated, Literal, Union from pydantic import Field, model_validator from invoke_training.config.base_pipeline_config import BasePipelineConfig from invoke_training.config.config_base_model import ConfigBaseModel from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig class HFHubImagePairPreferenceDatasetConfig(ConfigBaseModel): type: Literal["HF_HUB_IMAGE_PAIR_PREFERENCE_DATASET"] = "HF_HUB_IMAGE_PAIR_PREFERENCE_DATASET" # TODO(ryand): Fill this out. class ImagePairPreferenceDatasetConfig(ConfigBaseModel): type: Literal["IMAGE_PAIR_PREFERENCE_DATASET"] = "IMAGE_PAIR_PREFERENCE_DATASET" dataset_dir: str """The directory to load the dataset from.""" class ImagePairPreferenceSDDataLoaderConfig(ConfigBaseModel): type: Literal["IMAGE_PAIR_PREFERENCE_SD_DATA_LOADER"] = "IMAGE_PAIR_PREFERENCE_SD_DATA_LOADER" dataset: Annotated[ Union[HFHubImagePairPreferenceDatasetConfig, ImagePairPreferenceDatasetConfig], Field(discriminator="type") ] resolution: int | tuple[int, int] = 512 """The resolution for input images. Either a scalar integer representing the square resolution height and width, or a (height, width) tuple. All of the images in the dataset will be resized to this resolution unless the `aspect_ratio_buckets` config is set. """ center_crop: bool = True """If True, input images will be center-cropped to the target resolution. If False, input images will be randomly cropped to the target resolution. """ random_flip: bool = False """Whether random flip augmentations should be applied to input images. 
""" dataloader_num_workers: int = 0 """Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. """ class SdDirectPreferenceOptimizationLoraConfig(BasePipelineConfig): type: Literal["SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA"] = "SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA" model: str = "runwayml/stable-diffusion-v1-5" """Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint file. (E.g. 'runwayml/stable-diffusion-v1-5', '/path/to/realisticVisionV51_v51VAE.safetensors', etc. ) """ hf_variant: str | None = "fp16" """The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name. """ # Note: Pydantic handles mutable default values well: # https://docs.pydantic.dev/latest/concepts/models/#fields-with-non-hashable-default-values base_embeddings: dict[str, str] = {} """A mapping of embedding tokens to trained embedding file paths. These embeddings will be applied to the base model before training. Example: ``` base_embeddings = { "bruce_the_gnome": "/path/to/bruce_the_gnome.safetensors", } ``` Consider also adding the embedding tokens to the `data_loader.caption_prefix` if they are not already present in the dataset captions. Note that the embeddings themselves are not fine-tuned further, but they will impact the LoRA model training if they are referenced in the dataset captions. The list of embeddings provided here should be the same list used at generation time with the resultant LoRA model. """ lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya" """The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`.""" train_unet: bool = True """Whether to add LoRA layers to the UNet model and train it. """ train_text_encoder: bool = True """Whether to add LoRA layers to the text encoder and train it. 
""" optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig() text_encoder_learning_rate: float | None = None """The learning rate to use for the text encoder model. If set, this overrides the optimizer's default learning rate. Set to null or 0 to use the optimizer's default learning rate. """ unet_learning_rate: float | None = None """The learning rate to use for the UNet model. If set, this overrides the optimizer's default learning rate. Set to null or 0 to use the optimizer's default learning rate. """ lr_scheduler: Literal[ "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" ] = "constant" lr_warmup_steps: int = 0 """The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup. See lr_scheduler. """ min_snr_gamma: float | None = 5.0 """Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy improves the speed of training convergence by adjusting the weight of each sample. `min_snr_gamma` acts like an an upper bound on the weight of samples with low noise levels. If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`. """ lora_rank_dim: int = 4 """The rank dimension to use for the LoRA layers. Increasing the rank dimension increases the model's expressivity, but also increases the size of the generated LoRA model. """ cache_text_encoder_outputs: bool = False """If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the text encoders in VRAM), and speeds up training (don't have to run the text encoders for each training example). This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being applied. 
""" cache_vae_outputs: bool = False """If True, the VAE will be applied to all of the images in the dataset before starting training and the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False). """ enable_cpu_offload_during_validation: bool = False """If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation images. This reduces VRAM requirements at the cost of slower generation of validation images. """ gradient_accumulation_steps: int = 1 """The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face Accelerate. This is an alternative to increasing the batch size when training with limited VRAM. """ weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16" """All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and result in faster training, but are more prone to issues with numerical stability. Recommendations: - `"float32"`: Use this mode if you have plenty of VRAM available. - `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16. - `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16. See also [`mixed_precision`][invoke_training.pipelines._experimental.sd_dpo_lora.config.SdDirectPreferenceOptimizationLoraConfig.mixed_precision]. """ # noqa: E501 mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no" """The mixed precision mode to use. If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and trainable parameters are kept in float32 precision to avoid issues with numerical stability. 
This value is passed to Hugging Face Accelerate. See [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision) for more details. """ # noqa: E501 xformers: bool = False """If true, use xformers for more efficient attention blocks. """ gradient_checkpointing: bool = False """Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling gradient checkpointing slows down training by ~20%. """ max_checkpoints: int | None = None """The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately. """ prediction_type: Literal["epsilon", "v_prediction"] | None = None """The prediction_type that will be used for training. Choose between 'epsilon' or 'v_prediction' or leave 'None'. If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used. """ max_grad_norm: float | None = None """Max gradient norm for clipping. Set to null or 0 for no clipping. """ validation_prompts: list[str] = [] """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress. See also 'validate_every_n_epochs'. """ negative_validation_prompts: list[str] | None = None """A list of negative prompts that will be applied when generating validation images. If set, this list should have the same length as 'validation_prompts'. """ num_validation_images_per_prompt: int = 4 """The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can become quite slow if this number is too large. """ train_batch_size: int = 4 """The training batch size. """ data_loader: ImagePairPreferenceSDDataLoaderConfig initial_lora: str | None = None """The LoRA checkpoint directory to initialize the LoRA weights from. 
If set, the following configuration parameters are ignored: - `train_unet`: The UNet will be trained if it is present in `initial_lora`. - `train_text_encoder`: The text encoder will be trained if it is present in `initial_lora`. - `lora_rank_dim`: The LoRA rank dimension from `initial_lora` will be used. Currently only LoRA checkpoints in the internal `invoke-training` PEFT format are supported (i.e. checkpoints generated by an `invoke-training` training pipeline). """ beta: float = 5000.0 """The beta parameter, as defined in (https://arxiv.org/pdf/2311.12908.pdf). Larger beta values increase the KL-Divergence penalty, discouraging divergence from the reference model weights. Typical values for `beta` are in the range [1000.0, 10000.0]. """ @model_validator(mode="after") def check_validation_prompts(self): if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len( self.validation_prompts ): raise ValueError( f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of " f"negative_validation_prompts ({len(self.negative_validation_prompts)})." 
) return self ================================================ FILE: src/invoke_training/pipelines/_experimental/sd_dpo_lora/train.py ================================================ import copy import itertools import json import logging import math import os import tempfile import time from pathlib import Path from typing import Literal import peft import torch import torch.utils.data from accelerate.utils import set_seed from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel from diffusers.optimization import get_scheduler from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from invoke_training._shared.accelerator.accelerator_utils import ( get_dtype_from_str, initialize_accelerator, initialize_logging, ) from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker from invoke_training._shared.data.data_loaders.image_pair_preference_sd_dataloader import ( build_image_pair_preference_sd_dataloader, ) from invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import ( TEXT_ENCODER_TARGET_MODULES, UNET_TARGET_MODULES, load_sd_peft_checkpoint, save_sd_kohya_checkpoint, save_sd_peft_checkpoint, ) from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sd from invoke_training._shared.stable_diffusion.tokenize_captions import tokenize_captions from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sd from invoke_training._shared.utils.import_xformers import import_xformers from invoke_training.pipelines._experimental.sd_dpo_lora.config import SdDirectPreferenceOptimizationLoraConfig from invoke_training.pipelines.callbacks import PipelineCallbacks from invoke_training.pipelines.stable_diffusion.lora.train import cache_text_encoder_outputs def _save_sd_lora_checkpoint( epoch: int, step: int, unet: peft.PeftModel | None, text_encoder: 
peft.PeftModel | None, logger: logging.Logger, checkpoint_tracker: CheckpointTracker, lora_checkpoint_format: Literal["invoke_peft", "kohya"], ): # Prune checkpoints and get new checkpoint path. num_pruned = checkpoint_tracker.prune(1) if num_pruned > 0: logger.info(f"Pruned {num_pruned} checkpoint(s).") save_path = checkpoint_tracker.get_path(epoch=epoch, step=step) if lora_checkpoint_format == "invoke_peft": save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder) elif lora_checkpoint_format == "kohya": save_sd_kohya_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder) else: raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.") def train_forward_dpo( # noqa: C901 config: SdDirectPreferenceOptimizationLoraConfig, data_batch: dict, vae: AutoencoderKL, noise_scheduler: DDPMScheduler, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, unet: UNet2DConditionModel, ref_text_encoder: CLIPTextModel, ref_unet: UNet2DConditionModel, weight_dtype: torch.dtype, ) -> torch.Tensor: """Run the forward training pass for a single data_batch. This forward pass is based on 'Diffusion Model Alignment Using Direct Preference Optimization' (https://arxiv.org/pdf/2311.12908.pdf). See the "Pseudocode for Training Objective" Appendix section for a helpful reference. Returns: torch.Tensor: Loss """ batch_size = data_batch["image_0"].shape[0] # Concatenate image_0 and image_1 images into a single image batch. images = torch.concat((data_batch["image_0"], data_batch["image_1"])) # Re-order images so that the 'images' batch contains all winner images followed by all loser images. 
w_indices = [] l_indices = [] prefer_0 = data_batch["prefer_0"] prefer_1 = data_batch["prefer_1"] for i in range(batch_size): if prefer_0[i] and not prefer_1[i]: w_indices.append(i) l_indices.append(i + batch_size) elif not prefer_0[i] and prefer_1[i]: w_indices.append(i + batch_size) l_indices.append(i) else: raise ValueError(f"Encountered image pair with prefer_0={prefer_0[i]} and prefer_1={prefer_1[i]}.") images = images[w_indices + l_indices] # Update batch_size in case image pairs were filtered due to no-preference. batch_size = images.shape[0] // 2 # Convert images to latent space. # The VAE output may have been cached and included in the data_batch. If not, we calculate it here. latents = data_batch.get("vae_output", None) if latents is None: latents = vae.encode(images.to(dtype=weight_dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents. # We want to use the same noise for the winning and losing example in each pair, so we generate noise for the # winning latents and then repeat it. noise = torch.randn_like(latents[:batch_size]) noise = noise.repeat((2, 1, 1, 1)) # Sample a random timestep for each image **pair**. timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=latents.device) timesteps = timesteps.repeat((2,)).long() # Add noise to the latents according to the noise magnitude at each timestep (this is the forward # diffusion process). noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning (for both the text_encoder and ref_text_encoder). # The text_encoder_output may have been cached and included in the data_batch. If not, we calculate it here. 
encoder_hidden_states = data_batch.get("text_encoder_output", None) if encoder_hidden_states is None: caption_token_ids = tokenize_captions(tokenizer, data_batch["caption"]).to(text_encoder.device) encoder_hidden_states = text_encoder(caption_token_ids)[0].to(dtype=weight_dtype) ref_encoder_hidden_states = ref_text_encoder(caption_token_ids)[0].to(dtype=weight_dtype) encoder_hidden_states = encoder_hidden_states.repeat((2, 1, 1)) ref_encoder_hidden_states = ref_encoder_hidden_states.repeat((2, 1, 1)) # Get the target for loss depending on the prediction type. if config.prediction_type is not None: # Set the prediction_type of scheduler if it's defined in config. noise_scheduler.register_to_config(prediction_type=config.prediction_type) if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") # Predict the noise residual. ref_model_pred: torch.Tensor = ref_unet(noisy_latents, timesteps, ref_encoder_hidden_states).sample model_pred: torch.Tensor = unet(noisy_latents, timesteps, encoder_hidden_states).sample if "loss_weight" in data_batch: raise NotImplementedError("loss_weight is not yet supported.") target = target.float() w_target = target[:batch_size] l_target = target[batch_size:] model_w_pred = model_pred[:batch_size] model_l_pred = model_pred[batch_size:] ref_w_pred = ref_model_pred[:batch_size] ref_l_pred = ref_model_pred[batch_size:] # The pseudo-code from the paper uses `.norm().pow(2)` to calculate the errors. We take the mean over all pixels # rather than the sum over all pixels instead. This helps keep the learning rate stable across different image # resolutions. It also means that the the recommended settings for beta from the paper are not correct. 
# > model_w_err = (model_w_pred - target).norm().pow(2) # > model_l_err = (model_l_pred - target).norm().pow(2) # > ref_w_err = (ref_w_pred - target).norm().pow(2) # > ref_l_err = (ref_l_pred - target).norm().pow(2) model_w_err = torch.nn.functional.mse_loss(model_w_pred, w_target) model_l_err = torch.nn.functional.mse_loss(model_l_pred, l_target) ref_w_err = torch.nn.functional.mse_loss(ref_w_pred, w_target) ref_l_err = torch.nn.functional.mse_loss(ref_l_pred, l_target) w_diff = model_w_err - ref_w_err l_diff = model_l_err - ref_l_err inside_term = -1 * config.beta * (w_diff - l_diff) loss = -1 * torch.nn.functional.logsigmoid(inside_term) return loss def train(config: SdDirectPreferenceOptimizationLoraConfig, callbacks: list[PipelineCallbacks] | None = None): # noqa: C901 if callbacks: raise ValueError(f"This pipeline does not support callbacks, but {len(callbacks)} were provided.") # Give a clear error message if an unsupported base model was chosen. # TODO(ryan): Update this check to work with single-file SD checkpoints. # check_base_model_version( # {BaseModelVersionEnum.STABLE_DIFFUSION_V1, BaseModelVersionEnum.STABLE_DIFFUSION_V2}, # config.model, # local_files_only=False, # ) # Create a timestamped directory for all outputs. out_dir = os.path.join(config.base_output_dir, f"{time.time()}") ckpt_dir = os.path.join(out_dir, "checkpoints") os.makedirs(ckpt_dir) accelerator = initialize_accelerator( out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to ) logger = initialize_logging(os.path.basename(__file__), accelerator) # Set the accelerate seed. if config.seed is not None: set_seed(config.seed) # Log the accelerator configuration from every process to help with debugging. logger.info(accelerator.state, main_process_only=False) logger.info("Starting LoRA Training.") logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}") logger.info(f"Output dir: '{out_dir}'") # Write the configuration to disk. 
with open(os.path.join(out_dir, "config.json"), "w") as f: json.dump(config.dict(), f, indent=2, default=str) weight_dtype = get_dtype_from_str(config.weight_dtype) logger.info("Loading models.") tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd( logger=logger, model_name_or_path=config.model, hf_variant=config.hf_variant, base_embeddings=config.base_embeddings, dtype=weight_dtype, ) ref_text_encoder = copy.deepcopy(text_encoder) ref_unet = copy.deepcopy(unet) if config.xformers: import_xformers() # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers # will fail because Q, K, V have different dtypes. unet.enable_xformers_memory_efficient_attention() ref_unet.enable_xformers_memory_efficient_attention() vae.enable_xformers_memory_efficient_attention() # Prepare text encoder output cache. text_encoder_output_cache_dir_name = None if config.cache_text_encoder_outputs: # TODO(ryand): Think about how to better check if it is safe to cache the text encoder outputs. Currently, there # are a number of configurations that would cause variation in the text encoder outputs and should not be used # with caching. # TODO(ryand): This check does not make sense when config.initial_lora is set. if config.train_text_encoder: raise ValueError("'cache_text_encoder_outputs' and 'train_text_encoder' cannot both be True.") # We use a temporary directory for the cache. The directory will automatically be cleaned up when # tmp_text_encoder_output_cache_dir is destroyed. tmp_text_encoder_output_cache_dir = tempfile.TemporaryDirectory() text_encoder_output_cache_dir_name = tmp_text_encoder_output_cache_dir.name if accelerator.is_local_main_process: # Only the main process should populate the cache. 
logger.info(f"Generating text encoder output cache ('{text_encoder_output_cache_dir_name}').") text_encoder.to(accelerator.device, dtype=weight_dtype) cache_text_encoder_outputs(text_encoder_output_cache_dir_name, config, tokenizer, text_encoder) # Move the text_encoder back to the CPU, because it is not needed for training. text_encoder.to("cpu") accelerator.wait_for_everyone() else: text_encoder.to(accelerator.device, dtype=weight_dtype) ref_text_encoder.to(accelerator.device, dtype=weight_dtype) # Prepare VAE output cache. vae_output_cache_dir_name = None if config.cache_vae_outputs: raise NotImplementedError("VAE caching is not implemented for Diffusion-DPO training yet.") # # We use a temporary directory for the cache. The directory will automatically be cleaned up when # # tmp_vae_output_cache_dir is destroyed. # tmp_vae_output_cache_dir = tempfile.TemporaryDirectory() # vae_output_cache_dir_name = tmp_vae_output_cache_dir.name # if accelerator.is_local_main_process: # # Only the main process should populate the cache. # logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').") # vae.to(accelerator.device, dtype=weight_dtype) # data_loader = build_data_loader( # data_loader_config=config.data_loader, # batch_size=config.train_batch_size, # shuffle=False, # sequential_batching=True, # ) # cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae) # # Move the VAE back to the CPU, because it is not needed for training. # vae.to("cpu") # accelerator.wait_for_everyone() else: vae.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) ref_unet.to(accelerator.device, dtype=weight_dtype) # Add LoRA layers to the models being trained. trainable_param_groups = [] all_trainable_models: list[peft.PeftModel] = [] # Add LoRA layers to the model. 
trainable_param_groups = [] if config.initial_lora is not None: unet, text_encoder = load_sd_peft_checkpoint( checkpoint_dir=config.initial_lora, unet=unet, text_encoder=text_encoder, is_trainable=True ) ref_unet, ref_text_encoder = load_sd_peft_checkpoint( checkpoint_dir=config.initial_lora, unet=ref_unet, text_encoder=ref_text_encoder, is_trainable=False ) else: if config.train_unet: unet_lora_config = peft.LoraConfig( r=config.lora_rank_dim, # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred? lora_alpha=1.0, target_modules=UNET_TARGET_MODULES, ) unet = peft.get_peft_model(unet, unet_lora_config) if config.train_text_encoder: text_encoder_lora_config = peft.LoraConfig( r=config.lora_rank_dim, lora_alpha=1.0, # init_lora_weights="gaussian", target_modules=TEXT_ENCODER_TARGET_MODULES, ) text_encoder = peft.get_peft_model(text_encoder, text_encoder_lora_config) def prep_peft_model(model, lr: float | None = None): if not isinstance(model, peft.PeftModel): return False model.print_trainable_parameters() # Populate `trainable_param_groups`, to be passed to the optimizer. param_group = {"params": list(filter(lambda p: p.requires_grad, model.parameters()))} if lr is not None: param_group["lr"] = lr trainable_param_groups.append(param_group) # Populate all_trainable_models. all_trainable_models.append(model) model.train() return True training_unet = prep_peft_model(unet, config.unet_learning_rate) training_text_encoder = prep_peft_model(text_encoder, config.text_encoder_learning_rate) # If mixed_precision is enabled, cast all trainable params to float32. if config.mixed_precision != "no": for trainable_model in all_trainable_models: for param in trainable_model.parameters(): if param.requires_grad: param.data = param.to(torch.float32) if config.gradient_checkpointing: # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained. 
unet.enable_gradient_checkpointing() # unet must be in train() mode for gradient checkpointing to take effect. # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does # not change its forward behavior. unet.train() if training_text_encoder: text_encoder.gradient_checkpointing_enable() # The text encoder must be in train() mode for gradient checkpointing to take effect. This should # already be the case, since we are training the text_encoder, but we do it explicitly to make it clear # that this is required. # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text # encoders in train mode does not change their forward behavior. text_encoder.train() # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder # LoRA weights would have 0 gradients, and so would not get trained. Note that the set of # trainable_param_groups has already been populated - the embeddings will not be trained. text_encoder.text_model.embeddings.requires_grad_(True) optimizer = initialize_optimizer(config.optimizer, trainable_param_groups) data_loader = build_image_pair_preference_sd_dataloader( config=config.data_loader, batch_size=config.train_batch_size, text_encoder_output_cache_dir=text_encoder_output_cache_dir_name, text_encoder_cache_field_to_output_field={"text_encoder_output": "text_encoder_output"}, vae_output_cache_dir=vae_output_cache_dir_name, shuffle=True, ) # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears # in many places so I don't know where it originated. 
Internally, accelerate makes one LR scheduler step per process # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82), # so the scaling here simply reverses that behaviour. lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler( config.lr_scheduler, optimizer=optimizer, num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes, num_training_steps=config.max_train_steps * accelerator.num_processes, ) prepared_result: tuple[ UNet2DConditionModel | peft.PeftModel, CLIPTextModel | peft.PeftModel, torch.optim.Optimizer, torch.utils.data.DataLoader, torch.optim.lr_scheduler.LRScheduler, ] = accelerator.prepare( unet, text_encoder, optimizer, data_loader, lr_scheduler, # Disable automatic device placement for text_encoder if the text encoder outputs were cached. device_placement=[True, not config.cache_text_encoder_outputs, True, True, True], ) unet, text_encoder, optimizer, data_loader, lr_scheduler = prepared_result # Calculate the number of epochs and total training steps. A "step" represents a single weight update operation # (i.e. takes into account gradient accumulation steps). # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached. num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps) num_train_epochs = math.ceil(config.max_train_steps / num_steps_per_epoch) if accelerator.is_main_process: accelerator.init_trackers("lora_training") # Tensorboard uses markdown formatting, so we wrap the config json in a code block. 
accelerator.log({"configuration": f"```json\n{json.dumps(config.dict(), indent=2, default=str)}\n```\n"}) checkpoint_tracker = CheckpointTracker( base_dir=ckpt_dir, prefix="checkpoint", extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None, max_checkpoints=config.max_checkpoints, ) # Train! total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num examples = {len(data_loader)}") logger.info(f" Instantaneous batch size per device = {config.train_batch_size}") logger.info(f" Gradient accumulation steps = {config.gradient_accumulation_steps}") logger.info(f" Parallel processes = {accelerator.num_processes}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Total optimization steps = {config.max_train_steps}") global_step = 0 first_epoch = 0 completed_epochs = first_epoch progress_bar = tqdm( range(global_step, config.max_train_steps), # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) progress_bar.set_description("Steps") for epoch in range(first_epoch, num_train_epochs): train_loss = 0.0 for data_batch_idx, data_batch in enumerate(data_loader): with accelerator.accumulate(unet, text_encoder): loss = train_forward_dpo( config=config, data_batch=data_batch, vae=vae, noise_scheduler=noise_scheduler, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, ref_text_encoder=ref_text_encoder, ref_unet=ref_unet, weight_dtype=weight_dtype, ) # Gather the losses across all processes for logging (if we use distributed training). # TODO(ryand): Test that this works properly with distributed training. avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean() train_loss += avg_loss.item() / config.gradient_accumulation_steps # Backpropagate. 
accelerator.backward(loss) if accelerator.sync_gradients and config.max_grad_norm is not None: params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models]) accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) # Checks if the accelerator has performed an optimization step behind the scenes. if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1 log = {"train_loss": train_loss} lrs = lr_scheduler.get_last_lr() if training_unet: # When training the UNet, it will always be the first parameter group. log["lr/unet"] = float(lrs[0]) if config.optimizer.optimizer_type == "Prodigy": log["lr/d*lr/unet"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] if training_text_encoder: # When training the text encoder, it will always be the last parameter group. log["lr/text_encoder"] = float(lrs[-1]) if config.optimizer.optimizer_type == "Prodigy": log["lr/d*lr/text_encoder"] = optimizer.param_groups[-1]["d"] * optimizer.param_groups[-1]["lr"] accelerator.log(log, step=global_step) train_loss = 0.0 # global_step represents the *number of completed steps* at this point. if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0: accelerator.wait_for_everyone() if accelerator.is_main_process: _save_sd_lora_checkpoint( epoch=completed_epochs, step=global_step, unet=accelerator.unwrap_model(unet) if training_unet else None, text_encoder=accelerator.unwrap_model(text_encoder) if training_text_encoder else None, logger=logger, checkpoint_tracker=checkpoint_tracker, lora_checkpoint_format=config.lora_checkpoint_format, ) logs = { "step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], } progress_bar.set_postfix(**logs) if global_step >= config.max_train_steps: break # Save a checkpoint every n epochs. 
if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0: if accelerator.is_main_process: accelerator.wait_for_everyone() _save_sd_lora_checkpoint( epoch=completed_epochs, step=global_step, unet=accelerator.unwrap_model(unet) if training_unet else None, text_encoder=accelerator.unwrap_model(text_encoder) if training_text_encoder else None, logger=logger, checkpoint_tracker=checkpoint_tracker, lora_checkpoint_format=config.lora_checkpoint_format, ) # Generate validation images every n epochs. if len(config.validation_prompts) > 0 and completed_epochs % config.validate_every_n_epochs == 0: if accelerator.is_main_process: generate_validation_images_sd( epoch=completed_epochs, step=global_step, out_dir=out_dir, accelerator=accelerator, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, noise_scheduler=noise_scheduler, unet=unet, config=config, logger=logger, ) accelerator.end_training() ================================================ FILE: src/invoke_training/pipelines/callbacks.py ================================================ from abc import ABC from enum import Enum class ModelType(Enum): # At first glance, it feels like these model types should be further broken down into separate enums (e.g. # base_model, model_type, checkpoint_format). But, I haven't yet come up with a taxonomy that feels sufficiently # future-proof. So, for now, there is one enum for each file type that invoke-training can produce. # A Flux LoRA model in PEFT format. FLUX_LORA_PEFT = "FLUX_LORA_PEFT" # A Flux LoRA model in Kohya format. FLUX_LORA_KOHYA = "FLUX_LORA_KOHYA" # A Stable Diffusion 1.x LoRA model in Kohya format. SD1_LORA_KOHYA = "SD1_LORA_KOHYA" # A Stable Diffusion 1.x LoRA model in PEFT format. SD1_LORA_PEFT = "SD1_LORA_PEFT" # A Stable Diffusion XL LoRA model in Kohya format. SDXL_LORA_KOHYA = "SDXL_LORA_KOHYA" # A Stable Diffusion XL LoRA model in PEFT format. 
SDXL_LORA_PEFT = "SDXL_LORA_PEFT" # A Stable Diffusion 1.x Textual Inversion model. SD1_TEXTUAL_INVERSION = "SD1_TEXTUAL_INVERSION" # A Stable Diffusion XL Textual Inversion model. SDXL_TEXTUAL_INVERSION = "SDXL_TEXTUAL_INVERSION" # A Stable Diffusion 1.x UNet checkpoint in diffusers format. SD1_UNET_DIFFUSERS = "SD1_UNET_DIFFUSERS" # A Stable Diffusion XL UNet checkpoint in diffusers format. SDXL_UNET_DIFFUSERS = "SDXL_UNET_DIFFUSERS" # A full Stable Diffusion XL checkpoint in diffusers format. SDXL_FULL_DIFFUSERS = "SDXL_FULL_DIFFUSERS" class ModelCheckpoint: """A single model checkpoint.""" def __init__(self, file_path: str, model_type: ModelType): self.file_path = file_path self.model_type = model_type class TrainingCheckpoint: """A training checkpoint. May contain multiple model checkpoints if multiple models are being trained simultaneously. """ def __init__(self, models: list[ModelCheckpoint], epoch: int, step: int): self.models = models self.epoch = epoch self.step = step class ValidationImage: def __init__(self, file_path: str, prompt: str, image_idx: int): """A single validation image. Args: file_path (str): Path to the image file. prompt (str): The prompt used to generate the image. image_idx (int): The index of this image in the current validation set (i.e. in the set of images generated with the same prompt at the same validation point). """ self.file_path = file_path self.prompt = prompt self.image_idx = image_idx class ValidationImages: def __init__(self, images: list[ValidationImage], epoch: int, step: int): """A collection of validation images. Args: images (list[ValidationImage]): The validation images. epoch (int): The last completed epoch at the time that these images were generated. step (int): The last completed training step at the time that these images were generated. 
""" self.images = images self.epoch = epoch self.step = step class PipelineCallbacks(ABC): def on_save_checkpoint(self, checkpoint: TrainingCheckpoint): pass def on_save_validation_images(self, images: ValidationImages): pass ================================================ FILE: src/invoke_training/pipelines/flux/lora/__init__.py ================================================ ================================================ FILE: src/invoke_training/pipelines/flux/lora/config.py ================================================ from typing import Annotated, Literal, Union from pydantic import Field from invoke_training._shared.flux.lora_checkpoint_utils import ( FLUX_TRANSFORMER_TARGET_MODULES, TEXT_ENCODER_TARGET_MODULES, ) from invoke_training.config.base_pipeline_config import BasePipelineConfig from invoke_training.config.data.data_loader_config import ImageCaptionFluxDataLoaderConfig from invoke_training.config.optimizer.optimizer_config import ( AdamOptimizerConfig, ProdigyOptimizerConfig, ) class FluxLoraConfig(BasePipelineConfig): type: Literal["FLUX_LORA"] = "FLUX_LORA" model: str = "black-forest-labs/FLUX.1-dev" """Name or path of the base model to train. Can be in diffusers format, or a single Flux.1-dev checkpoint file. (E.g. 'black-forest-labs/FLUX.1-dev', '/path/to/flux.1-dev.safetensors', etc. ) """ transformer_path: str | None = None """Path to the custom transformer .safetensors file. If not provided, the default black-forest-labs/FLUX.1-dev transformer will be used. """ text_encoder_1_path: str | None = None """Path to the custom CLIP text encoder .safetensors file. If not provided, the default openai/clip-vit-base-patch32 text encoder will be used. """ text_encoder_2_path: str | None = None """Path to the custom T5 text encoder .safetensors file. If not provided, the default google/t5-v1_1-xl text encoder will be used. """ lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya" """The format of the LoRA checkpoint to save. 
Choose between `invoke_peft` or `kohya`.""" train_transformer: bool = True """Whether to add LoRA layers to the FluxTransformer2DModel and train it. """ train_text_encoder: bool = False """Whether to add LoRA layers to the text encoder and train it. """ optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig() text_encoder_learning_rate: float | None = 1e-4 """The learning rate to use for the text encoder model. If set, this overrides the optimizer's default learning rate. Set to null or 0 to use the optimizer's default learning rate. """ transformer_learning_rate: float | None = 4e-4 """The learning rate to use for the transformer model. If set, this overrides the optimizer's default learning rate. Set to null or 0 to use the optimizer's default learning rate. """ lr_scheduler: Literal[ "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" ] = "constant_with_warmup" lr_warmup_steps: int = 10 """The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup. See lr_scheduler. """ min_snr_gamma: float | None = None """Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy improves the speed of training convergence by adjusting the weight of each sample. `min_snr_gamma` acts like an an upper bound on the weight of samples with low noise levels. If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`. """ lora_rank_dim: int = 4 """The rank dimension to use for the LoRA layers. Increasing the rank dimension increases the model's expressivity, but also increases the size of the generated LoRA model. """ flux_lora_target_modules: list[str] = FLUX_TRANSFORMER_TARGET_MODULES """The list of target modules to apply LoRA layers to in the FluxTransformer2DModel. The default list will produce a highly expressive LoRA model. 
For a smaller and less expressive LoRA model, the following list is recommended: ```python flux_lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"] ``` The list of target modules is passed to Hugging Face's PEFT library. See [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for details. """ text_encoder_lora_target_modules: list[str] = TEXT_ENCODER_TARGET_MODULES """The list of target modules to apply LoRA layers to in the CLIP text encoder. The default list will produce a highly expressive LoRA model. For a smaller and less expressive LoRA model, the following list is recommended: ```python text_encoder_lora_target_modules = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"] ``` The list of target modules is passed to Hugging Face's PEFT library. See [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for details. """ cache_text_encoder_outputs: bool = False """If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the text encoders in VRAM), and speeds up training (don't have to run the text encoders for each training example). This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being applied. """ cache_vae_outputs: bool = False """If True, the VAE will be applied to all of the images in the dataset before starting training and the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False). 
""" enable_cpu_offload_during_validation: bool = False """If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation images. This reduces VRAM requirements at the cost of slower generation of validation images. """ gradient_accumulation_steps: int = 1 """The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face Accelerate. This is an alternative to increasing the batch size when training with limited VRAM. """ weight_dtype: Literal["float32", "float16", "bfloat16"] = "float16" """All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and result in faster training, but are more prone to issues with numerical stability. Recommendations: - `"float32"`: Use this mode if you have plenty of VRAM available. - `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16. - `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16. See also [`mixed_precision`][invoke_training.pipelines.flux.lora.config.FluxLoraConfig.mixed_precision]. """ # noqa: E501 mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no" """The mixed precision mode to use. If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and trainable parameters are kept in float32 precision to avoid issues with numerical stability. This value is passed to Hugging Face Accelerate. See [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision) for more details. """ # noqa: E501 gradient_checkpointing: bool = False """Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling gradient checkpointing slows down training by ~20%. 
""" max_checkpoints: int | None = None """The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately. """ prediction_type: Literal["epsilon", "v_prediction"] | None = None """The prediction_type that will be used for training. Choose between 'epsilon' or 'v_prediction' or leave 'None'. If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used. """ max_grad_norm: float | None = None """Max gradient norm for clipping. Set to null or 0 for no clipping. """ validation_prompts: list[str] = [] """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress. See also 'validate_every_n_epochs'. """ num_validation_images_per_prompt: int = 4 """The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can become quite slow if this number is too large. """ train_batch_size: int = 1 """The training batch size. """ use_masks: bool = False """If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this feature to be used. """ data_loader: Annotated[Union[ImageCaptionFluxDataLoaderConfig], Field(discriminator="type")] timestep_sampler: Literal["shift", "uniform"] = "shift" """The timestep sampler to use. Choose between 'shift' or 'uniform'.""" discrete_flow_shift: float = 3.0 """The shift parameter for the discrete flow. Only used if `timestep_sampler == "shift"`. """ sigmoid_scale: float = 1.0 """The scale parameter for the sigmoid function. Only used if `timestep_sampler == "shift"`. """ lora_scale: float | None = 1.0 """The scale parameter for the LoRA layers. If set, this overrides the optimizer's default learning rate. """ guidance_scale: float = 1.0 """The guidance scale for the Flux model. 
""" train_transformer: bool = True """Whether to train the Flux transformer (FluxTransformer2DModel) model. """ clip_tokenizer_max_length: int = 77 """The maximum length of the CLIP tokenizer. The maximum length of the CLIP tokenizer is 77. """ t5_tokenizer_max_length: int = 512 """The maximum length of the T5 tokenizer. The maximum length of the T5 tokenizer is 512. """ ================================================ FILE: src/invoke_training/pipelines/flux/lora/train.py ================================================ import itertools import json import logging import math import os import tempfile import time from pathlib import Path from typing import Literal, Optional, Union import numpy as np import peft import torch import torch.utils.data from accelerate.utils import set_seed from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxTransformer2DModel from diffusers.optimization import get_scheduler from diffusers.pipelines.flux.pipeline_flux import FluxPipeline from peft import PeftModel from PIL import Image from torch.utils.data import DataLoader from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer from invoke_training._shared.accelerator.accelerator_utils import ( get_dtype_from_str, initialize_accelerator, initialize_logging, ) from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker from invoke_training._shared.data.data_loaders.image_caption_flux_dataloader import build_image_caption_flux_dataloader from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache from invoke_training._shared.flux.encoding_utils import encode_prompt from invoke_training._shared.flux.lora_checkpoint_utils import ( save_flux_kohya_checkpoint, save_flux_peft_checkpoint, ) from invoke_training._shared.flux.model_loading_utils import load_models_flux from invoke_training._shared.flux.validation import generate_validation_images_flux from 
invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer from invoke_training._shared.stable_diffusion.tokenize_captions import tokenize_captions from invoke_training.config.data.data_loader_config import ImageCaptionSDDataLoaderConfig from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint from invoke_training.pipelines.flux.lora.config import FluxLoraConfig def _save_flux_lora_checkpoint( epoch: int, step: int, transformer: peft.PeftModel | None, text_encoder_1: CLIPTextModel | None, text_encoder_2: T5EncoderModel | None, logger: logging.Logger, checkpoint_tracker: CheckpointTracker, callbacks: list[PipelineCallbacks] | None, lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "invoke_peft", ): # Prune checkpoints and get new checkpoint path. num_pruned = checkpoint_tracker.prune(1) if num_pruned > 0: logger.info(f"Pruned {num_pruned} checkpoint(s).") save_path = checkpoint_tracker.get_path(epoch=epoch, step=step) if lora_checkpoint_format == "invoke_peft": model_type = ModelType.FLUX_LORA_PEFT save_flux_peft_checkpoint( Path(save_path), transformer=transformer, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2 ) elif lora_checkpoint_format == "kohya": model_type = ModelType.FLUX_LORA_KOHYA save_flux_kohya_checkpoint( Path(save_path), transformer=transformer, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2 ) else: raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.") if callbacks is not None: for cb in callbacks: cb.on_save_checkpoint( TrainingCheckpoint( models=[ModelCheckpoint(file_path=save_path, model_type=model_type)], epoch=epoch, step=step ) ) def _build_data_loader( data_loader_config: Union[ImageCaptionSDDataLoaderConfig], batch_size: int, use_masks: bool = False, text_encoder_output_cache_dir: Optional[str] = None, vae_output_cache_dir: Optional[str] = None, shuffle: bool = True, sequential_batching: bool = False, ) 
-> DataLoader: if data_loader_config.type == "IMAGE_CAPTION_FLUX_DATA_LOADER": return build_image_caption_flux_dataloader( config=data_loader_config, batch_size=batch_size, use_masks=use_masks, text_encoder_output_cache_dir=text_encoder_output_cache_dir, text_encoder_cache_field_to_output_field={"text_encoder_output": "text_encoder_output"}, vae_output_cache_dir=vae_output_cache_dir, shuffle=shuffle, ) else: raise ValueError(f"Unsupported data loader config type: '{data_loader_config.type}'.") def cache_text_encoder_outputs( cache_dir: str, config: FluxLoraConfig, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel ): """Run the text encoder on all captions in the dataset and cache the results to disk. Args: cache_dir (str): The directory where the results will be cached. config (FluxLoraConfig): Training config. tokenizer (CLIPTokenizer): The tokenizer. text_encoder (CLIPTextModel): The text_encoder. """ data_loader = _build_data_loader( data_loader_config=config.data_loader, batch_size=config.train_batch_size, shuffle=False, sequential_batching=True, ) cache = TensorDiskCache(cache_dir) for data_batch in tqdm(data_loader): caption_token_ids = tokenize_captions(tokenizer, data_batch["caption"]).to(text_encoder.device) text_encoder_output_batch = text_encoder(caption_token_ids)[0] # Split batch before caching. for i in range(len(data_batch["id"])): cache.save(data_batch["id"][i], {"text_encoder_output": text_encoder_output_batch[i]}) def cache_vae_outputs(cache_dir: str, data_loader: DataLoader, vae: AutoencoderKL): """Run the VAE on all images in the dataset and cache the results to disk.""" cache = TensorDiskCache(cache_dir) for data_batch in tqdm(data_loader): latents = vae.encode(data_batch["image"].to(device=vae.device, dtype=vae.dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor # Split batch before caching. 
for i in range(len(data_batch["id"])): data = { "vae_output": latents[i], "original_size_hw": data_batch["original_size_hw"][i], "crop_top_left_yx": data_batch["crop_top_left_yx"][i], } if "mask" in data_batch: data["mask"] = data_batch["mask"][i] cache.save(data_batch["id"][i], data) def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32): sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) schedule_timesteps = noise_scheduler.timesteps.to(device) timesteps = timesteps.to(device) step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < n_dim: sigma = sigma.unsqueeze(-1) return sigma def get_noisy_latents(noise_scheduler: FlowMatchEulerDiscreteScheduler, latents: torch.Tensor, config: FluxLoraConfig): """ Generate random noise. Sample a random timestep from the distribution chosen by the config. Linearly interpolate between the latents and the noise based on timestep. See Section 3.1 of https://arxiv.org/pdf/2403.03206v1 for timestep sampling. Args: noise_scheduler (FlowMatchEulerDiscreteScheduler): The noise scheduler. latents (torch.Tensor): The latents. config (FluxLoraConfig): The config. Returns: torch.Tensor: The noisy latents. 
""" batch_size = latents.shape[0] dtype = latents.dtype device = latents.device noise = torch.randn_like(latents) if config.timestep_sampler == "shift": shift = config.discrete_flow_shift sigmas = torch.randn(batch_size, device=device) sigmas = sigmas * config.sigmoid_scale # larger scale for more uniform sampling sigmas = sigmas.sigmoid() sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas) timesteps = sigmas * noise_scheduler.config.num_train_timesteps else: u = torch.rand(size=(batch_size,), device="cpu") indices = (u * noise_scheduler.config.num_train_timesteps).long() timesteps = noise_scheduler.timesteps[indices].to(device=device) sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) sigmas = sigmas.view(-1, 1, 1, 1) # Linearly interpolate between the latents and the noise. noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise return noisy_model_input.to(dtype), noise.to(dtype), timesteps.to(dtype), sigmas.to(dtype) def decode_latents(vae: AutoencoderKL, latents: torch.Tensor): latents = latents / vae.config.scaling_factor image = vae.decode(latents).sample # tensor to image image = image.cpu().numpy() image = (image * 255).astype(np.uint8) image = Image.fromarray(image) image.save("image.png") return image def train_forward( # noqa: C901 config: FluxLoraConfig, data_batch: dict, vae: AutoencoderKL, noise_scheduler: FlowMatchEulerDiscreteScheduler, tokenizer_1: CLIPTokenizer, tokenizer_2: T5Tokenizer, text_encoder_1: CLIPTextModel, text_encoder_2: T5EncoderModel, transformer: FluxTransformer2DModel | PeftModel, weight_dtype: torch.dtype, use_masks: bool = False, min_snr_gamma: float | None = None, logger: logging.Logger = None, ) -> torch.Tensor: """Run the forward training pass for a single data_batch. Returns: torch.Tensor: Loss """ # Convert images to latent space. # The VAE output may have been cached and included in the data_batch. If not, we calculate it here. 
latents = data_batch.get("vae_output", None) if latents is None: # Cast input image to same dtype as VAE image = data_batch["image"].to(device=vae.device, dtype=vae.dtype) latents = vae.encode(image).latent_dist.sample() batch_size, num_channels, height, width = latents.shape latents = latents * vae.config.scaling_factor latents = FluxPipeline._pack_latents(latents, batch_size, num_channels, height, width) else: batch_size, num_channels, height, width = latents.shape # Sample noise that we'll add to the latents. latent_image_ids = FluxPipeline._prepare_latent_image_ids( batch_size, height // 2, width // 2, latents.device, latents.dtype ) # Add noise to the latents according to the noise magnitude at each timestep (this is the forward # diffusion process). noisy_latents, noise, timesteps, sigmas = get_noisy_latents(noise_scheduler, latents, config) # Get the text embedding for conditioning. # The text encoder output may have been cached and included in the data_batch. If not, we calculate it here. 
if "prompt_embeds" in data_batch: prompt_embeds = data_batch["prompt_embeds"] pooled_prompt_embeds = data_batch["pooled_prompt_embeds"] else: prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( prompt=data_batch["caption"], prompt_2=data_batch.get("caption_2", None), clip_tokenizer=tokenizer_1, t5_tokenizer=tokenizer_2, clip_text_encoder=text_encoder_1, t5_text_encoder=text_encoder_2, device=latents.device, num_images_per_prompt=1, lora_scale=config.lora_scale, clip_tokenizer_max_length=config.clip_tokenizer_max_length, t5_tokenizer_max_length=config.t5_tokenizer_max_length, logger=logger, ) guidance = torch.full((batch_size,), float(config.guidance_scale), device=latents.device) model_pred = transformer( hidden_states=noisy_latents[0], timestep=timesteps / 1000, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, guidance=guidance, txt_ids=text_ids, img_ids=latent_image_ids, return_dict=False, )[0] ### Flow matching loss # See here for more discussion:https://discuss.huggingface.co/t/meaning-of-vector-fields-in-flux-and-sd3-loss-function/106601 target = noise - latents loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) return loss.mean() def train(config: FluxLoraConfig, callbacks: list[PipelineCallbacks] | None = None): # noqa: C901 # Create a timestamped directory for all outputs. out_dir = os.path.join(config.base_output_dir, f"{time.time()}") ckpt_dir = os.path.join(out_dir, "checkpoints") os.makedirs(ckpt_dir) accelerator = initialize_accelerator( out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to ) logger = initialize_logging(os.path.basename(__file__), accelerator) # Set the accelerate seed. if config.seed is not None: set_seed(config.seed) # Log the accelerator configuration from every process to help with debugging. 
logger.info(accelerator.state, main_process_only=False) logger.info("Starting LoRA Training.") logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}") logger.info(f"Output dir: '{out_dir}'") # Write the configuration to disk. with open(os.path.join(out_dir, "config.json"), "w") as f: json.dump(config.dict(), f, indent=2, default=str) weight_dtype = get_dtype_from_str(config.weight_dtype) logger.info("Loading models.") tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, transformer = load_models_flux( model_name_or_path=config.model, transformer_path=config.transformer_path, text_encoder_1_path=config.text_encoder_1_path, text_encoder_2_path=config.text_encoder_2_path, dtype=weight_dtype, logger=logger, ) # Prepare text encoder output cache. text_encoder_output_cache_dir_name = None if config.cache_text_encoder_outputs: # TODO(ryand): Think about how to better check if it is safe to cache the text encoder outputs. Currently, there # are a number of configurations that would cause variation in the text encoder outputs and should not be used # with caching. # We use a temporary directory for the cache. The directory will automatically be cleaned up when # tmp_text_encoder_output_cache_dir is destroyed. tmp_text_encoder_output_cache_dir = tempfile.TemporaryDirectory() text_encoder_output_cache_dir_name = tmp_text_encoder_output_cache_dir.name if accelerator.is_local_main_process: # Only the main process should populate the cache. logger.info(f"Generating text encoder output cache ('{text_encoder_output_cache_dir_name}').") text_encoder_1.to(accelerator.device, dtype=weight_dtype) text_encoder_2.to(accelerator.device, dtype=weight_dtype) # TODO(ryan): Move cache_text_encoder_outputs to a shared location so that it is not imported from another # pipeline. 
cache_text_encoder_outputs( text_encoder_output_cache_dir_name, config, tokenizer_1, tokenizer_2, text_encoder_1, text_encoder_2 ) # Move the text_encoders back to the CPU, because they are not needed for training. text_encoder_1.to("cpu") text_encoder_2.to("cpu") accelerator.wait_for_everyone() else: text_encoder_1.to(accelerator.device, dtype=weight_dtype) text_encoder_2.to(accelerator.device, dtype=weight_dtype) # Prepare VAE output cache. # vae_output_cache_dir_name = None if config.cache_vae_outputs: if config.data_loader.random_flip: raise ValueError("'cache_vae_outputs' cannot be True if 'random_flip' is True.") if not config.data_loader.center_crop: raise ValueError("'cache_vae_outputs' cannot be True if 'center_crop' is False.") # We use a temporary directory for the cache. The directory will automatically be cleaned up when # tmp_vae_output_cache_dir is destroyed. tmp_vae_output_cache_dir = tempfile.TemporaryDirectory() vae_output_cache_dir_name = tmp_vae_output_cache_dir.name if accelerator.is_local_main_process: # Only the main process should populate the cache. logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').") vae.to(accelerator.device, dtype=weight_dtype) data_loader = _build_data_loader( data_loader_config=config.data_loader, batch_size=config.train_batch_size, shuffle=False, sequential_batching=True, ) cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae) # Move the VAE back to the CPU, because it is not needed for training. vae.to("cpu") accelerator.wait_for_everyone() else: vae.to(accelerator.device, dtype=weight_dtype) transformer.to(accelerator.device, dtype=weight_dtype) # Add LoRA layers to the models being trained. 
trainable_param_groups = [] all_trainable_models: list[peft.PeftModel] = [] def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = None) -> peft.PeftModel: peft_model = peft.get_peft_model(model, lora_config) peft_model.print_trainable_parameters() # Populate `trainable_param_groups`, to be passed to the optimizer. param_group = {"params": list(filter(lambda p: p.requires_grad, peft_model.parameters()))} if lr is not None: param_group["lr"] = lr trainable_param_groups.append(param_group) # Populate all_trainable_models. all_trainable_models.append(peft_model) peft_model.train() return peft_model # Add LoRA layers to the model. if config.train_transformer: transformer_lora_config = peft.LoraConfig( r=config.lora_rank_dim, # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred? lora_alpha=1.0, target_modules=config.flux_lora_target_modules, ) transformer = inject_lora_layers(transformer, transformer_lora_config, lr=config.transformer_learning_rate) if config.train_text_encoder: text_encoder_lora_config = peft.LoraConfig( r=config.lora_rank_dim, lora_alpha=1.0, # init_lora_weights="gaussian", target_modules=config.text_encoder_lora_target_modules, ) text_encoder_1 = inject_lora_layers( text_encoder_1, text_encoder_lora_config, lr=config.text_encoder_learning_rate ) # Enable gradient checkpointing. if config.gradient_checkpointing: # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained. transformer.enable_gradient_checkpointing() # unet must be in train() mode for gradient checkpointing to take effect. # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does # not change its forward behavior. transformer.train() if config.train_text_encoder: text_encoder_1.gradient_checkpointing_enable() # The text encoders must be in train() mode for gradient checkpointing to take effect. 
This should # already be the case, since we are training the text_encoders, be we do it explicitly to make it clear # that this is required. # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text # encoders in train mode does not change their forward behavior. text_encoder_1.train() # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder # LoRA weights would have 0 gradients, and so would not get trained. Note that the set of # trainable_param_groups has already been populated - the embeddings will not be trained. text_encoder_1.text_model.embeddings.requires_grad_(True) optimizer = initialize_optimizer(config.optimizer, trainable_param_groups) data_loader = _build_data_loader( data_loader_config=config.data_loader, batch_size=config.train_batch_size, # text_encoder_output_cache_dir=text_encoder_output_cache_dir_name, # vae_output_cache_dir=vae_output_cache_dir_name, ) assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1 assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1 assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1 # A "step" represents a single weight update operation (i.e. takes into account gradient accumulation steps). # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached. num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps) num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch) # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps # by the number of processes. 
This scaling logic was copied from the diffusers example training code, but it appears # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82), # so the scaling here simply reverses that behaviour. lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler( config.lr_scheduler, optimizer=optimizer, num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes, num_training_steps=num_train_steps * accelerator.num_processes, ) prepared_result: tuple[ FluxTransformer2DModel, CLIPTextModel, T5EncoderModel, torch.optim.Optimizer, torch.utils.data.DataLoader, torch.optim.lr_scheduler.LRScheduler, ] = accelerator.prepare( transformer, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler, # Disable automatic device placement for text_encoder if the text encoder outputs were cached. device_placement=[ True, not config.cache_text_encoder_outputs, not config.cache_text_encoder_outputs, True, True, True, ], ) transformer, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result if accelerator.is_main_process: accelerator.init_trackers("lora_training") # Tensorboard uses markdown formatting, so we wrap the config json in a code block. accelerator.log({"configuration": f"```json\n{json.dumps(config.dict(), indent=2, default=str)}\n```\n"}) checkpoint_tracker = CheckpointTracker( base_dir=ckpt_dir, prefix="checkpoint", max_checkpoints=config.max_checkpoints, extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None, ) # Train! 
total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num batches = {len(data_loader)}") logger.info(f" Instantaneous batch size per device = {config.train_batch_size}") logger.info(f" Gradient accumulation steps = {config.gradient_accumulation_steps}") logger.info(f" Parallel processes = {accelerator.num_processes}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Total optimization steps = {num_train_steps}") logger.info(f" Total epochs = {num_train_epochs}") global_step = 0 first_epoch = 0 completed_epochs = 0 progress_bar = tqdm( range(global_step, num_train_steps), # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) progress_bar.set_description("Steps") def save_checkpoint(num_completed_epochs: int, num_completed_steps: int): accelerator.wait_for_everyone() if accelerator.is_main_process: _save_flux_lora_checkpoint( epoch=num_completed_epochs, step=num_completed_steps, transformer=transformer if config.train_transformer else None, text_encoder_1=text_encoder_1 if config.train_text_encoder else None, text_encoder_2=text_encoder_2 if config.train_text_encoder else None, logger=logger, checkpoint_tracker=checkpoint_tracker, lora_checkpoint_format=config.lora_checkpoint_format, callbacks=callbacks, ) accelerator.wait_for_everyone() def validate(num_completed_epochs: int, num_completed_steps: int): accelerator.wait_for_everyone() if accelerator.is_main_process: generate_validation_images_flux( epoch=num_completed_epochs, step=num_completed_steps, out_dir=out_dir, accelerator=accelerator, vae=vae, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, noise_scheduler=noise_scheduler, transformer=transformer, config=config, logger=logger, callbacks=callbacks, ) 
accelerator.wait_for_everyone() for epoch in range(first_epoch, num_train_epochs): train_loss = 0.0 for data_batch_idx, data_batch in enumerate(data_loader): # (Pdb) data_batch['image'].shape # torch.Size([4, 3, 512, 512]) with accelerator.accumulate(transformer, text_encoder_1, text_encoder_2): loss = train_forward( config=config, data_batch=data_batch, vae=vae, noise_scheduler=noise_scheduler, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, transformer=transformer, weight_dtype=weight_dtype, min_snr_gamma=config.min_snr_gamma, ) # Gather the losses across all processes for logging (if we use distributed training). # TODO(ryand): Test that this works properly with distributed training. avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean() train_loss += avg_loss.item() / config.gradient_accumulation_steps # Backpropagate. accelerator.backward(loss) if accelerator.sync_gradients and config.max_grad_norm is not None: params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models]) accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) # Checks if the accelerator has performed an optimization step behind the scenes. if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1 log = {"train_loss": train_loss} lrs = lr_scheduler.get_last_lr() if config.train_transformer: # When training the UNet, it will always be the first parameter group. log["lr/transformer"] = float(lrs[0]) if config.optimizer.optimizer_type == "Prodigy": log["lr/d*lr/transformer"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] if config.train_text_encoder: # When training the text encoder, it will always be the last parameter group. 
log["lr/text_encoder"] = float(lrs[-1]) if config.optimizer.optimizer_type == "Prodigy": log["lr/d*lr/text_encoder"] = optimizer.param_groups[-1]["d"] * optimizer.param_groups[-1]["lr"] accelerator.log(log, step=global_step) train_loss = 0.0 # global_step represents the *number of completed steps* at this point. if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0: save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step) if ( config.validate_every_n_steps is not None and global_step % config.validate_every_n_steps == 0 and len(config.validation_prompts) > 0 ): validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step) logs = { "step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], } progress_bar.set_postfix(**logs) if global_step >= num_train_steps: break # Save a checkpoint every n epochs. if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0: save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step) # Generate validation images every n epochs. 
if ( config.validate_every_n_epochs is not None and completed_epochs % config.validate_every_n_epochs == 0 and len(config.validation_prompts) > 0 ): validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step) accelerator.end_training() ================================================ FILE: src/invoke_training/pipelines/invoke_train.py ================================================ import os from invoke_training.config.pipeline_config import PipelineConfig from invoke_training.pipelines._experimental.sd_dpo_lora.train import train as train_sd_ddpo_lora from invoke_training.pipelines.callbacks import PipelineCallbacks from invoke_training.pipelines.flux.lora.train import train as train_flux_lora from invoke_training.pipelines.stable_diffusion.lora.train import train as train_sd_lora from invoke_training.pipelines.stable_diffusion.textual_inversion.train import train as train_sd_ti from invoke_training.pipelines.stable_diffusion_xl.finetune.train import train as train_sdxl_finetune from invoke_training.pipelines.stable_diffusion_xl.lora.train import train as train_sdxl_lora from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.train import ( train as train_sdxl_lora_and_ti, ) from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.train import train as train_sdxl_ti def train(config: PipelineConfig, callbacks: list[PipelineCallbacks] | None = None): """This is the main entry point for all training pipelines.""" # Fail early if invalid callback types are provided, rather than failing later when the callbacks are used. 
for cb in callbacks or []: assert isinstance(cb, PipelineCallbacks) if config.type == "FLUX_LORA": # Disable tokenizer parallelism to avoid issues with tokenization os.environ["TOKENIZERS_PARALLELISM"] = "false" train_flux_lora(config, callbacks) elif config.type == "SD_LORA": train_sd_lora(config, callbacks) elif config.type == "SDXL_LORA": train_sdxl_lora(config, callbacks) elif config.type == "SD_TEXTUAL_INVERSION": train_sd_ti(config, callbacks) elif config.type == "SDXL_TEXTUAL_INVERSION": train_sdxl_ti(config, callbacks) elif config.type == "SDXL_LORA_AND_TEXTUAL_INVERSION": train_sdxl_lora_and_ti(config, callbacks) elif config.type == "SDXL_FINETUNE": train_sdxl_finetune(config, callbacks) elif config.type == "SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA": print(f"Running EXPERIMENTAL pipeline: '{config.type}'.") train_sd_ddpo_lora(config, callbacks) else: raise ValueError(f"Unexpected pipeline type: '{config.type}'.") ================================================ FILE: src/invoke_training/pipelines/stable_diffusion/__init__.py ================================================ ================================================ FILE: src/invoke_training/pipelines/stable_diffusion/lora/__init__.py ================================================ ================================================ FILE: src/invoke_training/pipelines/stable_diffusion/lora/config.py ================================================ from typing import Annotated, Literal, Union from pydantic import Field, model_validator from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import ( TEXT_ENCODER_TARGET_MODULES, UNET_TARGET_MODULES, ) from invoke_training.config.base_pipeline_config import BasePipelineConfig from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig class SdLoraConfig(BasePipelineConfig): type: 
Literal["SD_LORA"] = "SD_LORA" model: str = "runwayml/stable-diffusion-v1-5" """Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint file. (E.g. 'runwayml/stable-diffusion-v1-5', '/path/to/realisticVisionV51_v51VAE.safetensors', etc. ) """ hf_variant: str | None = "fp16" """The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name. """ # Note: Pydantic handles mutable default values well: # https://docs.pydantic.dev/latest/concepts/models/#fields-with-non-hashable-default-values base_embeddings: dict[str, str] = {} """A mapping of embedding tokens to trained embedding file paths. These embeddings will be applied to the base model before training. Example: ``` base_embeddings = { "bruce_the_gnome": "/path/to/bruce_the_gnome.safetensors", } ``` Consider also adding the embedding tokens to the `data_loader.caption_prefix` if they are not already present in the dataset captions. Note that the embeddings themselves are not fine-tuned further, but they will impact the LoRA model training if they are referenced in the dataset captions. The list of embeddings provided here should be the same list used at generation time with the resultant LoRA model. """ lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya" """The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`.""" train_unet: bool = True """Whether to add LoRA layers to the UNet model and train it. """ train_text_encoder: bool = True """Whether to add LoRA layers to the text encoder and train it. """ optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig() text_encoder_learning_rate: float | None = None """The learning rate to use for the text encoder model. If set, this overrides the optimizer's default learning rate. Set to null or 0 to use the optimizer's default learning rate. 
""" unet_learning_rate: float | None = None """The learning rate to use for the UNet model. If set, this overrides the optimizer's default learning rate. Set to null or 0 to use the optimizer's default learning rate. """ lr_scheduler: Literal[ "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" ] = "constant" lr_warmup_steps: int = 0 """The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup. See lr_scheduler. """ min_snr_gamma: float | None = 5.0 """Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy improves the speed of training convergence by adjusting the weight of each sample. `min_snr_gamma` acts like an an upper bound on the weight of samples with low noise levels. If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`. """ lora_rank_dim: int = 4 """The rank dimension to use for the LoRA layers. Increasing the rank dimension increases the model's expressivity, but also increases the size of the generated LoRA model. """ # The default list of target modules is based on # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/stable_diffusion/train_dreambooth.py#L49C1-L65C87 unet_lora_target_modules: list[str] = UNET_TARGET_MODULES """The list of target modules to apply LoRA layers to in the UNet model. The default list will produce a highly expressive LoRA model. For a smaller and less expressive LoRA model, the following list is recommended: ```python unet_lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"] ``` The list of target modules is passed to Hugging Face's PEFT library. See [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for details. 
""" text_encoder_lora_target_modules: list[str] = TEXT_ENCODER_TARGET_MODULES """The list of target modules to apply LoRA layers to in the text encoder models. The default list will produce a highly expressive LoRA model. For a smaller and less expressive LoRA model, the following list is recommended: ```python text_encoder_lora_target_modules = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"] ``` The list of target modules is passed to Hugging Face's PEFT library. See [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for details. """ cache_text_encoder_outputs: bool = False """If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the text encoders in VRAM), and speeds up training (don't have to run the text encoders for each training example). This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being applied. """ cache_vae_outputs: bool = False """If True, the VAE will be applied to all of the images in the dataset before starting training and the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False). """ enable_cpu_offload_during_validation: bool = False """If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation images. This reduces VRAM requirements at the cost of slower generation of validation images. """ gradient_accumulation_steps: int = 1 """The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face Accelerate. 
This is an alternative to increasing the batch size when training with limited VRAM. """ weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16" """All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and result in faster training, but are more prone to issues with numerical stability. Recommendations: - `"float32"`: Use this mode if you have plenty of VRAM available. - `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16. - `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16. See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion.lora.config.SdLoraConfig.mixed_precision]. """ # noqa: E501 mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no" """The mixed precision mode to use. If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and trainable parameters are kept in float32 precision to avoid issues with numerical stability. This value is passed to Hugging Face Accelerate. See [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision) for more details. """ # noqa: E501 xformers: bool = False """If true, use xformers for more efficient attention blocks. """ gradient_checkpointing: bool = False """Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling gradient checkpointing slows down training by ~20%. """ max_checkpoints: int | None = None """The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately. """ prediction_type: Literal["epsilon", "v_prediction"] | None = None """The prediction_type that will be used for training. 
Choose between 'epsilon' or 'v_prediction' or leave 'None'. If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used. """ max_grad_norm: float | None = None """Max gradient norm for clipping. Set to null or 0 for no clipping. """ validation_prompts: list[str] = [] """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress. See also 'validate_every_n_epochs'. """ negative_validation_prompts: list[str] | None = None """A list of negative prompts that will be applied when generating validation images. If set, this list should have the same length as 'validation_prompts'. """ num_validation_images_per_prompt: int = 4 """The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can become quite slow if this number is too large. """ train_batch_size: int = 4 """The training batch size. """ use_masks: bool = False """If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this feature to be used. """ data_loader: Annotated[ Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig], Field(discriminator="type") ] @model_validator(mode="after") def check_validation_prompts(self): if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len( self.validation_prompts ): raise ValueError( f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of " f"negative_validation_prompts ({len(self.negative_validation_prompts)})." 
) return self ================================================ FILE: src/invoke_training/pipelines/stable_diffusion/lora/train.py ================================================ import itertools import json import logging import math import os import tempfile import time from pathlib import Path from typing import Literal, Optional, Union import peft import torch import torch.utils.data from accelerate.utils import set_seed from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel from diffusers.optimization import get_scheduler from torch.utils.data import DataLoader from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from invoke_training._shared.accelerator.accelerator_utils import ( get_dtype_from_str, initialize_accelerator, initialize_logging, ) from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker from invoke_training._shared.data.data_loaders.dreambooth_sd_dataloader import build_dreambooth_sd_dataloader from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import build_image_caption_sd_dataloader from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache from invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import ( save_sd_kohya_checkpoint, save_sd_peft_checkpoint, ) from invoke_training._shared.stable_diffusion.min_snr_weighting import compute_snr from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sd from invoke_training._shared.stable_diffusion.tokenize_captions import tokenize_captions from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sd from invoke_training._shared.utils.import_xformers import import_xformers from 
invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig def _save_sd_lora_checkpoint( epoch: int, step: int, unet: peft.PeftModel | None, text_encoder: peft.PeftModel | None, logger: logging.Logger, checkpoint_tracker: CheckpointTracker, lora_checkpoint_format: Literal["invoke_peft", "kohya"], callbacks: list[PipelineCallbacks] | None, ): # Prune checkpoints and get new checkpoint path. num_pruned = checkpoint_tracker.prune(1) if num_pruned > 0: logger.info(f"Pruned {num_pruned} checkpoint(s).") save_path = checkpoint_tracker.get_path(epoch=epoch, step=step) if lora_checkpoint_format == "invoke_peft": model_type = ModelType.SD1_LORA_PEFT save_sd_peft_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder) elif lora_checkpoint_format == "kohya": model_type = ModelType.SD1_LORA_KOHYA save_sd_kohya_checkpoint(Path(save_path), unet=unet, text_encoder=text_encoder) else: raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.") if callbacks is not None: for cb in callbacks: cb.on_save_checkpoint( TrainingCheckpoint( models=[ModelCheckpoint(file_path=save_path, model_type=model_type)], epoch=epoch, step=step ) ) def _build_data_loader( data_loader_config: Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig], batch_size: int, use_masks: bool = False, text_encoder_output_cache_dir: Optional[str] = None, vae_output_cache_dir: Optional[str] = None, shuffle: bool = True, sequential_batching: bool = False, ) -> DataLoader: if data_loader_config.type == "IMAGE_CAPTION_SD_DATA_LOADER": return build_image_caption_sd_dataloader( config=data_loader_config, batch_size=batch_size, use_masks=use_masks, text_encoder_output_cache_dir=text_encoder_output_cache_dir, 
text_encoder_cache_field_to_output_field={"text_encoder_output": "text_encoder_output"}, vae_output_cache_dir=vae_output_cache_dir, shuffle=shuffle, ) elif data_loader_config.type == "DREAMBOOTH_SD_DATA_LOADER": if use_masks: raise NotImplementedError("Masks are not yet supported for DreamBooth data loaders.") return build_dreambooth_sd_dataloader( config=data_loader_config, batch_size=batch_size, text_encoder_output_cache_dir=text_encoder_output_cache_dir, text_encoder_cache_field_to_output_field={"text_encoder_output": "text_encoder_output"}, vae_output_cache_dir=vae_output_cache_dir, shuffle=shuffle, sequential_batching=sequential_batching, ) else: raise ValueError(f"Unsupported data loader config type: '{data_loader_config.type}'.") def cache_text_encoder_outputs( cache_dir: str, config: SdLoraConfig, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel ): """Run the text encoder on all captions in the dataset and cache the results to disk. Args: cache_dir (str): The directory where the results will be cached. config (SdLoraConfig): Training config. tokenizer (CLIPTokenizer): The tokenizer. text_encoder (CLIPTextModel): The text_encoder. """ data_loader = _build_data_loader( data_loader_config=config.data_loader, batch_size=config.train_batch_size, shuffle=False, sequential_batching=True, ) cache = TensorDiskCache(cache_dir) for data_batch in tqdm(data_loader): caption_token_ids = tokenize_captions(tokenizer, data_batch["caption"]).to(text_encoder.device) text_encoder_output_batch = text_encoder(caption_token_ids)[0] # Split batch before caching. 
for i in range(len(data_batch["id"])): cache.save(data_batch["id"][i], {"text_encoder_output": text_encoder_output_batch[i]}) def cache_vae_outputs(cache_dir: str, data_loader: DataLoader, vae: AutoencoderKL): """Run the VAE on all images in the dataset and cache the results to disk.""" cache = TensorDiskCache(cache_dir) for data_batch in tqdm(data_loader): latents = vae.encode(data_batch["image"].to(device=vae.device, dtype=vae.dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor # Split batch before caching. for i in range(len(data_batch["id"])): data = { "vae_output": latents[i], "original_size_hw": data_batch["original_size_hw"][i], "crop_top_left_yx": data_batch["crop_top_left_yx"][i], } if "mask" in data_batch: data["mask"] = data_batch["mask"][i] cache.save(data_batch["id"][i], data) def train_forward( # noqa: C901 config: SdLoraConfig, data_batch: dict, vae: AutoencoderKL, noise_scheduler: DDPMScheduler, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, unet: UNet2DConditionModel, weight_dtype: torch.dtype, use_masks: bool = False, min_snr_gamma: float | None = None, ) -> torch.Tensor: """Run the forward training pass for a single data_batch. Returns: torch.Tensor: Loss """ # Convert images to latent space. # The VAE output may have been cached and included in the data_batch. If not, we calculate it here. latents = data_batch.get("vae_output", None) if latents is None: latents = vae.encode(data_batch["image"].to(dtype=weight_dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents. noise = torch.randn_like(latents) batch_size = latents.shape[0] # Sample a random timestep for each image. timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=latents.device, ) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep (this is the forward # diffusion process). 
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning. # The text_encoder_output may have been cached and included in the data_batch. If not, we calculate it here. encoder_hidden_states = data_batch.get("text_encoder_output", None) if encoder_hidden_states is None: caption_token_ids = tokenize_captions(tokenizer, data_batch["caption"]).to(text_encoder.device) encoder_hidden_states = text_encoder(caption_token_ids)[0].to(dtype=weight_dtype) # Get the target for loss depending on the prediction type. if config.prediction_type is not None: # Set the prediction_type of scheduler if it's defined in config. noise_scheduler.register_to_config(prediction_type=config.prediction_type) if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") # Predict the noise residual. model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample min_snr_weights = None if min_snr_gamma is not None: # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) # Note: We divide by snr here per Section 4.2 of the paper, since we are predicting the noise instead of x_0. # w_t = min(1, SNR(t)) / SNR(t) min_snr_weights = torch.clamp(snr, max=min_snr_gamma) / snr if noise_scheduler.config.prediction_type == "epsilon": pass elif noise_scheduler.config.prediction_type == "v_prediction": # Velocity objective needs to be floored to an SNR weight of one. 
min_snr_weights = min_snr_weights + 1 else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none") if use_masks: # TODO(ryand): As a future performance optimization, we may want to do this resizing in the dataloader. mask = data_batch["mask"].to(dtype=loss.dtype, device=loss.device) _, _, latent_h, latent_w = loss.shape mask = torch.nn.functional.interpolate(mask, size=(latent_h, latent_w), mode="nearest") loss = loss * mask # Mean-reduce the loss along all dimensions except for the batch dimension. loss = loss.mean(dim=list(range(1, len(loss.shape)))) # Apply min_snr_weights. if min_snr_weights is not None: loss = loss * min_snr_weights # Apply per-example loss weights. if "loss_weight" in data_batch: loss = loss * data_batch["loss_weight"] return loss.mean() def train(config: SdLoraConfig, callbacks: list[PipelineCallbacks] | None = None): # noqa: C901 # Give a clear error message if an unsupported base model was chosen. # TODO(ryan): Update this check to work with single-file SD checkpoints. # check_base_model_version( # {BaseModelVersionEnum.STABLE_DIFFUSION_V1, BaseModelVersionEnum.STABLE_DIFFUSION_V2}, # config.model, # local_files_only=False, # ) # Create a timestamped directory for all outputs. out_dir = os.path.join(config.base_output_dir, f"{time.time()}") ckpt_dir = os.path.join(out_dir, "checkpoints") os.makedirs(ckpt_dir) accelerator = initialize_accelerator( out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to ) logger = initialize_logging(os.path.basename(__file__), accelerator) # Set the accelerate seed. if config.seed is not None: set_seed(config.seed) # Log the accelerator configuration from every process to help with debugging. 
logger.info(accelerator.state, main_process_only=False) logger.info("Starting LoRA Training.") logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}") logger.info(f"Output dir: '{out_dir}'") # Write the configuration to disk. with open(os.path.join(out_dir, "config.json"), "w") as f: json.dump(config.dict(), f, indent=2, default=str) weight_dtype = get_dtype_from_str(config.weight_dtype) logger.info("Loading models.") tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd( logger=logger, model_name_or_path=config.model, hf_variant=config.hf_variant, base_embeddings=config.base_embeddings, dtype=weight_dtype, ) if config.xformers: import_xformers() # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers # will fail because Q, K, V have different dtypes. unet.enable_xformers_memory_efficient_attention() vae.enable_xformers_memory_efficient_attention() # Prepare text encoder output cache. text_encoder_output_cache_dir_name = None if config.cache_text_encoder_outputs: # TODO(ryand): Think about how to better check if it is safe to cache the text encoder outputs. Currently, there # are a number of configurations that would cause variation in the text encoder outputs and should not be used # with caching. if config.train_text_encoder: raise ValueError("'cache_text_encoder_outputs' and 'train_text_encoder' cannot both be True.") # We use a temporary directory for the cache. The directory will automatically be cleaned up when # tmp_text_encoder_output_cache_dir is destroyed. tmp_text_encoder_output_cache_dir = tempfile.TemporaryDirectory() text_encoder_output_cache_dir_name = tmp_text_encoder_output_cache_dir.name if accelerator.is_local_main_process: # Only the main process should populate the cache. 
logger.info(f"Generating text encoder output cache ('{text_encoder_output_cache_dir_name}').") text_encoder.to(accelerator.device, dtype=weight_dtype) cache_text_encoder_outputs(text_encoder_output_cache_dir_name, config, tokenizer, text_encoder) # Move the text_encoder back to the CPU, because it is not needed for training. text_encoder.to("cpu") accelerator.wait_for_everyone() else: text_encoder.to(accelerator.device, dtype=weight_dtype) # Prepare VAE output cache. vae_output_cache_dir_name = None if config.cache_vae_outputs: if config.data_loader.random_flip: raise ValueError("'cache_vae_outputs' cannot be True if 'random_flip' is True.") if not config.data_loader.center_crop: raise ValueError("'cache_vae_outputs' cannot be True if 'center_crop' is False.") # We use a temporary directory for the cache. The directory will automatically be cleaned up when # tmp_vae_output_cache_dir is destroyed. tmp_vae_output_cache_dir = tempfile.TemporaryDirectory() vae_output_cache_dir_name = tmp_vae_output_cache_dir.name if accelerator.is_local_main_process: # Only the main process should populate the cache. logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').") vae.to(accelerator.device, dtype=weight_dtype) data_loader = _build_data_loader( data_loader_config=config.data_loader, batch_size=config.train_batch_size, use_masks=config.use_masks, shuffle=False, sequential_batching=True, ) cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae) # Move the VAE back to the CPU, because it is not needed for training. vae.to("cpu") accelerator.wait_for_everyone() else: vae.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) # Add LoRA layers to the models being trained. 
trainable_param_groups = [] all_trainable_models: list[peft.PeftModel] = [] def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = None) -> peft.PeftModel: peft_model = peft.get_peft_model(model, lora_config) peft_model.print_trainable_parameters() # Populate `trainable_param_groups`, to be passed to the optimizer. param_group = {"params": list(filter(lambda p: p.requires_grad, peft_model.parameters()))} if lr is not None: param_group["lr"] = lr trainable_param_groups.append(param_group) # Populate all_trainable_models. all_trainable_models.append(peft_model) peft_model.train() return peft_model # Add LoRA layers to the model. if config.train_unet: unet_lora_config = peft.LoraConfig( r=config.lora_rank_dim, # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred? lora_alpha=1.0, target_modules=config.unet_lora_target_modules, ) unet = inject_lora_layers(unet, unet_lora_config, lr=config.unet_learning_rate) if config.train_text_encoder: text_encoder_lora_config = peft.LoraConfig( r=config.lora_rank_dim, lora_alpha=1.0, # init_lora_weights="gaussian", target_modules=config.text_encoder_lora_target_modules, ) text_encoder = inject_lora_layers(text_encoder, text_encoder_lora_config, lr=config.text_encoder_learning_rate) # If mixed_precision is enabled, cast all trainable params to float32. if config.mixed_precision != "no": for trainable_model in all_trainable_models: for param in trainable_model.parameters(): if param.requires_grad: param.data = param.to(torch.float32) if config.gradient_checkpointing: # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained. unet.enable_gradient_checkpointing() # unet must be in train() mode for gradient checkpointing to take effect. # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does # not change its forward behavior. 
unet.train() if config.train_text_encoder: text_encoder.gradient_checkpointing_enable() # The text encoder must be in train() mode for gradient checkpointing to take effect. This should # already be the case, since we are training the text_encoder, but we do it explicitly to make it clear # that this is required. # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text # encoders in train mode does not change their forward behavior. text_encoder.train() # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder # LoRA weights would have 0 gradients, and so would not get trained. Note that the set of # trainable_param_groups has already been populated - the embeddings will not be trained. text_encoder.text_model.embeddings.requires_grad_(True) optimizer = initialize_optimizer(config.optimizer, trainable_param_groups) data_loader = _build_data_loader( data_loader_config=config.data_loader, batch_size=config.train_batch_size, use_masks=config.use_masks, text_encoder_output_cache_dir=text_encoder_output_cache_dir_name, vae_output_cache_dir=vae_output_cache_dir_name, ) log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler) assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1 assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1 assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1 # A "step" represents a single weight update operation (i.e. takes into account gradient accumulation steps). # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached. 
num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps) num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch) # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82), # so the scaling here simply reverses that behaviour. lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler( config.lr_scheduler, optimizer=optimizer, num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes, num_training_steps=num_train_steps * accelerator.num_processes, ) prepared_result: tuple[ UNet2DConditionModel, CLIPTextModel, torch.optim.Optimizer, torch.utils.data.DataLoader, torch.optim.lr_scheduler.LRScheduler, ] = accelerator.prepare( unet, text_encoder, optimizer, data_loader, lr_scheduler, # Disable automatic device placement for text_encoder if the text encoder outputs were cached. device_placement=[True, not config.cache_text_encoder_outputs, True, True, True], ) unet, text_encoder, optimizer, data_loader, lr_scheduler = prepared_result if accelerator.is_main_process: accelerator.init_trackers("lora_training") # Tensorboard uses markdown formatting, so we wrap the config json in a code block. accelerator.log({"configuration": f"```json\n{json.dumps(config.dict(), indent=2, default=str)}\n```\n"}) checkpoint_tracker = CheckpointTracker( base_dir=ckpt_dir, prefix="checkpoint", max_checkpoints=config.max_checkpoints, extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None, ) # Train! 
total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num batches = {len(data_loader)}") logger.info(f" Instantaneous batch size per device = {config.train_batch_size}") logger.info(f" Gradient accumulation steps = {config.gradient_accumulation_steps}") logger.info(f" Parallel processes = {accelerator.num_processes}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Total optimization steps = {num_train_steps}") logger.info(f" Total epochs = {num_train_epochs}") global_step = 0 first_epoch = 0 completed_epochs = 0 progress_bar = tqdm( range(global_step, num_train_steps), # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) progress_bar.set_description("Steps") def save_checkpoint(num_completed_epochs: int, num_completed_steps: int): accelerator.wait_for_everyone() if accelerator.is_main_process: _save_sd_lora_checkpoint( epoch=num_completed_epochs, step=num_completed_steps, unet=accelerator.unwrap_model(unet) if config.train_unet else None, text_encoder=accelerator.unwrap_model(text_encoder) if config.train_text_encoder else None, logger=logger, checkpoint_tracker=checkpoint_tracker, lora_checkpoint_format=config.lora_checkpoint_format, callbacks=callbacks, ) accelerator.wait_for_everyone() def validate(num_completed_epochs: int, num_completed_steps: int): accelerator.wait_for_everyone() if accelerator.is_main_process: generate_validation_images_sd( epoch=num_completed_epochs, step=num_completed_steps, out_dir=out_dir, accelerator=accelerator, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, noise_scheduler=noise_scheduler, unet=unet, config=config, logger=logger, callbacks=callbacks, ) accelerator.wait_for_everyone() for epoch in range(first_epoch, num_train_epochs): train_loss = 0.0 for data_batch_idx, data_batch in 
enumerate(data_loader): with accelerator.accumulate(unet, text_encoder): loss = train_forward( config=config, data_batch=data_batch, vae=vae, noise_scheduler=noise_scheduler, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, weight_dtype=weight_dtype, use_masks=config.use_masks, min_snr_gamma=config.min_snr_gamma, ) # Gather the losses across all processes for logging (if we use distributed training). # TODO(ryand): Test that this works properly with distributed training. avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean() train_loss += avg_loss.item() / config.gradient_accumulation_steps # Backpropagate. accelerator.backward(loss) if accelerator.sync_gradients and config.max_grad_norm is not None: params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models]) accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) # Checks if the accelerator has performed an optimization step behind the scenes. if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1 log = {"train_loss": train_loss} lrs = lr_scheduler.get_last_lr() if config.train_unet: # When training the UNet, it will always be the first parameter group. log["lr/unet"] = float(lrs[0]) if config.optimizer.optimizer_type == "Prodigy": log["lr/d*lr/unet"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] if config.train_text_encoder: # When training the text encoder, it will always be the last parameter group. log["lr/text_encoder"] = float(lrs[-1]) if config.optimizer.optimizer_type == "Prodigy": log["lr/d*lr/text_encoder"] = optimizer.param_groups[-1]["d"] * optimizer.param_groups[-1]["lr"] accelerator.log(log, step=global_step) train_loss = 0.0 # global_step represents the *number of completed steps* at this point. 
if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0: save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step) if ( config.validate_every_n_steps is not None and global_step % config.validate_every_n_steps == 0 and len(config.validation_prompts) > 0 ): validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step) logs = { "step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], } progress_bar.set_postfix(**logs) if global_step >= num_train_steps: break # Save a checkpoint every n epochs. if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0: save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step) # Generate validation images every n epochs. if ( config.validate_every_n_epochs is not None and completed_epochs % config.validate_every_n_epochs == 0 and len(config.validation_prompts) > 0 ): validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step) accelerator.end_training() ================================================ FILE: src/invoke_training/pipelines/stable_diffusion/textual_inversion/__init__.py ================================================ ================================================ FILE: src/invoke_training/pipelines/stable_diffusion/textual_inversion/config.py ================================================ from typing import Literal from pydantic import model_validator from invoke_training.config.base_pipeline_config import BasePipelineConfig from invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig class SdTextualInversionConfig(BasePipelineConfig): type: Literal["SD_TEXTUAL_INVERSION"] = "SD_TEXTUAL_INVERSION" """Must be `SD_TEXTUAL_INVERSION`. This is what differentiates training pipeline types. 
""" model: str """Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint file. (E.g. `"runwayml/stable-diffusion-v1-5"`, `"stabilityai/stable-diffusion-xl-base-1.0"`, `"/path/to/local/model.safetensors"`, etc.) The model architecture must match the training pipeline being run. For example, if running a Textual Inversion SDXL pipeline, then `model` must refer to an SDXL model. """ hf_variant: str | None = "fp16" """The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name. """ # Helpful discussion for understanding how this works at inference time: # https://github.com/huggingface/diffusers/pull/3144#discussion_r1172413509 num_vectors: int = 1 """Note: `num_vectors` can be overridden by `initial_phrase`. The number of textual inversion embedding vectors that will be used to learn the concept. Increasing the `num_vectors` enables the model to learn more complex concepts, but has the following drawbacks: - greater risk of overfitting - increased size of the resulting output file - consumes more of the prompt capacity at inference time Typical values for `num_vectors` are in the range [1, 16]. As a rule of thumb, `num_vectors` can be increased as the size of the dataset increases (without overfitting). """ placeholder_token: str """The special word to associate the learned embeddings with. Choose a unique token that is unlikely to already exist in the tokenizer's vocabulary. """ initializer_token: str | None = None """Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set. A vocabulary token to use as an initializer for the placeholder token. It should be a single word that roughly describes the object or style that you're trying to train on. Must map to a single tokenizer token. For example, if you are training on a dataset of images of your pet dog, a good choice would be `dog`. 
""" initial_embedding_file: str | None = None """Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set. Path to an existing TI embedding that will be used to initialize the embedding being trained. The placeholder token in the file must match the `placeholder_token` field. Either `initializer_token` or `initial_embedding_file` should be set. """ initial_phrase: str | None = None """Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set. A phrase that will be used to initialize the placeholder token embedding. The phrase will be tokenized, and the corresponding embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be inferred from the length of the tokenized phrase, so keep the phrase short. The consequences of training a large number of embedding vectors are discussed in the `num_vectors` field documentation. For example, if you are training on a dataset of images of pokemon, you might use `pokemon sketch white background`. """ optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig() lr_scheduler: Literal[ "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" ] = "constant" lr_warmup_steps: int = 0 """The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup. See lr_scheduler. """ min_snr_gamma: float | None = 5.0 """Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy improves the speed of training convergence by adjusting the weight of each sample. `min_snr_gamma` acts like an an upper bound on the weight of samples with low noise levels. If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`. 
""" cache_vae_outputs: bool = False """If True, the VAE will be applied to all of the images in the dataset before starting training and the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all non-deterministic image augmentations are disabled (i.e. `center_crop=True`, `random_flip=False`, etc.). """ enable_cpu_offload_during_validation: bool = False """If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation images. This reduces VRAM requirements at the cost of slower generation of validation images. """ gradient_accumulation_steps: int = 1 """The number of gradient steps to accumulate before each weight update. This is an alternative to increasing the `train_batch_size` when training with limited VRAM. """ weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16" """All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and result in faster training, but are more prone to issues with numerical stability. Recommendations: - `"float32"`: Use this mode if you have plenty of VRAM available. - `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16. - `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16. See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion.textual_inversion.config.SdTextualInversionConfig.mixed_precision]. """ # noqa: E501 mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no" """The mixed precision mode to use. If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and trainable parameters are kept in float32 precision to avoid issues with numerical stability. This value is passed to Hugging Face Accelerate. 
See [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision) for more details. """ # noqa: E501 xformers: bool = False """If `True`, use xformers for more efficient attention blocks. """ gradient_checkpointing: bool = False """Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling gradient checkpointing slows down training by ~20%. """ max_checkpoints: int | None = None """The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately. """ prediction_type: Literal["epsilon", "v_prediction"] | None = None """The prediction type that will be used for training. If `None`, the prediction type will be inferred from the scheduler. """ max_grad_norm: float | None = None """Maximum gradient norm for gradient clipping. Set to `null` or 0 for no clipping. """ validation_prompts: list[str] = [] """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress. """ negative_validation_prompts: list[str] | None = None """A list of negative prompts that will be applied when generating validation images. If set, this list should have the same length as 'validation_prompts'. """ num_validation_images_per_prompt: int = 4 """The number of validation images to generate for each prompt in `validation_prompts`. Careful, validation can become very slow if this number is too large. """ train_batch_size: int = 4 """The training batch size. """ use_masks: bool = False """If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this feature to be used. """ data_loader: TextualInversionSDDataLoaderConfig """The data configuration. 
See [`TextualInversionSDDataLoaderConfig`][invoke_training.config.data.data_loader_config.TextualInversionSDDataLoaderConfig] for details. """ @model_validator(mode="after") def check_validation_prompts(self): if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len( self.validation_prompts ): raise ValueError( f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of " f"negative_validation_prompts ({len(self.negative_validation_prompts)})." ) return self ================================================ FILE: src/invoke_training/pipelines/stable_diffusion/textual_inversion/train.py ================================================ import json import logging import math import os import tempfile import time import torch from accelerate import Accelerator from accelerate.utils import set_seed from diffusers.optimization import get_scheduler from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer, PreTrainedTokenizer from invoke_training._shared.accelerator.accelerator_utils import ( get_dtype_from_str, initialize_accelerator, initialize_logging, ) from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker from invoke_training._shared.checkpoints.serialization import save_state_dict from invoke_training._shared.data.data_loaders.textual_inversion_sd_dataloader import ( build_textual_inversion_sd_dataloader, ) from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets from invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sd from invoke_training._shared.stable_diffusion.textual_inversion import ( initialize_placeholder_tokens_from_initial_embedding, initialize_placeholder_tokens_from_initial_phrase, initialize_placeholder_tokens_from_initializer_token, restore_original_embeddings, ) 
from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sd from invoke_training._shared.utils.import_xformers import import_xformers from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint from invoke_training.pipelines.stable_diffusion.lora.train import cache_vae_outputs, train_forward from invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig def _save_ti_embeddings( epoch: int, step: int, text_encoder: CLIPTextModel, placeholder_token_ids: list[int], accelerator: Accelerator, logger: logging.Logger, checkpoint_tracker: CheckpointTracker, callbacks: list[PipelineCallbacks] | None, ): """Save a Textual Inversion checkpoint. Old checkpoints are deleted if necessary to respect the checkpoint_tracker limits. """ # Prune checkpoints and get new checkpoint path. num_pruned = checkpoint_tracker.prune(1) if num_pruned > 0: logger.info(f"Pruned {num_pruned} checkpoint(s).") save_path = checkpoint_tracker.get_path(epoch=epoch, step=step) learned_embeds = ( accelerator.unwrap_model(text_encoder) .get_input_embeddings() .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] ) learned_embeds_dict = {"emb_params": learned_embeds.detach().cpu().to(torch.float32)} save_state_dict(learned_embeds_dict, save_path) if callbacks is not None: for cb in callbacks: cb.on_save_checkpoint( TrainingCheckpoint( models=[ModelCheckpoint(file_path=save_path, model_type=ModelType.SD1_TEXTUAL_INVERSION)], epoch=epoch, step=step, ) ) def _initialize_placeholder_tokens( config: SdTextualInversionConfig, tokenizer: CLIPTokenizer, text_encoder: PreTrainedTokenizer, logger: logging.Logger, ) -> tuple[list[str], list[int]]: """Prepare the tokenizer and text_encoder for TI training. - Add the placeholder tokens to the tokenizer. - Add new token embeddings to the text_encoder for each of the placeholder tokens. 
- Initialize the new token embeddings from either an existing token, or an initial TI embedding file. """ if ( sum( [ config.initializer_token is not None, config.initial_embedding_file is not None, config.initial_phrase is not None, ] ) != 1 ): raise ValueError( "Exactly one of 'initializer_token', 'initial_embedding_file', or 'initial_phrase' should be set." ) if config.initializer_token is not None: placeholder_tokens, placeholder_token_ids = initialize_placeholder_tokens_from_initializer_token( tokenizer=tokenizer, text_encoder=text_encoder, initializer_token=config.initializer_token, placeholder_token=config.placeholder_token, num_vectors=config.num_vectors, logger=logger, ) elif config.initial_embedding_file is not None: placeholder_tokens, placeholder_token_ids = initialize_placeholder_tokens_from_initial_embedding( tokenizer=tokenizer, text_encoder=text_encoder, initial_embedding_file=config.initial_embedding_file, placeholder_token=config.placeholder_token, num_vectors=config.num_vectors, ) elif config.initial_phrase is not None: placeholder_tokens, placeholder_token_ids = initialize_placeholder_tokens_from_initial_phrase( tokenizer=tokenizer, text_encoder=text_encoder, initial_phrase=config.initial_phrase, placeholder_token=config.placeholder_token, ) else: raise ValueError( "Exactly one of 'initializer_token', 'initial_embedding_file', or 'initial_phrase' should be set." ) return placeholder_tokens, placeholder_token_ids def train(config: SdTextualInversionConfig, callbacks: list[PipelineCallbacks] | None = None): # noqa: C901 # Create a timestamped directory for all outputs. out_dir = os.path.join(config.base_output_dir, f"{time.time()}") ckpt_dir = os.path.join(out_dir, "checkpoints") os.makedirs(ckpt_dir) accelerator = initialize_accelerator( out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to ) logger = initialize_logging(os.path.basename(__file__), accelerator) # Set the accelerate seed. 
if config.seed is not None: set_seed(config.seed) # Log the accelerator configuration from every process to help with debugging. logger.info(accelerator.state, main_process_only=False) logger.info("Starting Textual Inversion Training.") logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}") logger.info(f"Output dir: '{out_dir}'") # Write the configuration to disk. with open(os.path.join(out_dir, "config.json"), "w") as f: json.dump(config.dict(), f, indent=2, default=str) weight_dtype = get_dtype_from_str(config.weight_dtype) logger.info("Loading models.") tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd( logger=logger, model_name_or_path=config.model, hf_variant=config.hf_variant, dtype=weight_dtype ) placeholder_tokens, placeholder_token_ids = _initialize_placeholder_tokens( config=config, tokenizer=tokenizer, text_encoder=text_encoder, logger=logger ) logger.info(f"Initialized {len(placeholder_tokens)} placeholder tokens: {placeholder_tokens}.") # All parameters of the VAE, UNet, and text encoder are currently frozen. Just unfreeze the token embeddings in the # text encoder. text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) if config.gradient_checkpointing: # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained. unet.enable_gradient_checkpointing() # unet must be in train() mode for gradient checkpointing to take effect. # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does # not change its forward behavior. unet.train() # The text_encoder will be put in .train() mode later, so we don't need to worry about that here. # Note: There are some weird interactions gradient checkpointing and requires_grad_() when training a # text_encoder LoRA. If this code ever gets copied elsewhere, make sure to take a look at how this is handled in # other training pipelines. 
text_encoder.gradient_checkpointing_enable() if config.xformers: import_xformers() unet.enable_xformers_memory_efficient_attention() vae.enable_xformers_memory_efficient_attention() # Prepare VAE output cache. vae_output_cache_dir_name = None if config.cache_vae_outputs: if config.data_loader.random_flip: raise ValueError("'cache_vae_outputs' cannot be True if 'random_flip' is True.") if not config.data_loader.center_crop: raise ValueError("'cache_vae_outputs' cannot be True if 'center_crop' is False.") # We use a temporary directory for the cache. The directory will automatically be cleaned up when # tmp_vae_output_cache_dir is destroyed. tmp_vae_output_cache_dir = tempfile.TemporaryDirectory() vae_output_cache_dir_name = tmp_vae_output_cache_dir.name if accelerator.is_local_main_process: # Only the main process should populate the cache. logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').") vae.to(accelerator.device, dtype=weight_dtype) data_loader = build_textual_inversion_sd_dataloader( config=config.data_loader, placeholder_token=config.placeholder_token, batch_size=config.train_batch_size, use_masks=config.use_masks, shuffle=False, ) cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae) # Move the VAE back to the CPU, because it is not needed for training. vae.to("cpu") accelerator.wait_for_everyone() else: vae.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) # Initialize the optimizer to only optimize the token embeddings. 
optimizer = initialize_optimizer(config.optimizer, text_encoder.get_input_embeddings().parameters()) data_loader = build_textual_inversion_sd_dataloader( config=config.data_loader, placeholder_token=config.placeholder_token, batch_size=config.train_batch_size, use_masks=config.use_masks, vae_output_cache_dir=vae_output_cache_dir_name, ) log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler) assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1 assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1 assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1 # A "step" represents a single weight update operation (i.e. takes into account gradient accumulation steps). # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached. num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps) num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch) # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82), # so the scaling here simply reverses that behaviour. 
lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler( config.lr_scheduler, optimizer=optimizer, num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes, num_training_steps=num_train_steps * accelerator.num_processes, ) # Prepare everything with our `accelerator`. text_encoder, optimizer, data_loader, lr_scheduler = accelerator.prepare( text_encoder, optimizer, data_loader, lr_scheduler ) prepared_result: tuple[ CLIPTextModel, torch.optim.Optimizer, torch.utils.data.DataLoader, torch.optim.lr_scheduler.LRScheduler ] = accelerator.prepare(text_encoder, optimizer, data_loader, lr_scheduler) text_encoder, optimizer, data_loader, lr_scheduler = prepared_result if accelerator.is_main_process: accelerator.init_trackers("textual_inversion_training") # Tensorboard uses markdown formatting, so we wrap the config json in a code block. accelerator.log({"configuration": f"```json\n{json.dumps(config.dict(), indent=2, default=str)}\n```\n"}) checkpoint_tracker = CheckpointTracker( base_dir=ckpt_dir, prefix="checkpoint", extension=".safetensors", max_checkpoints=config.max_checkpoints, ) # Train! total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num batches = {len(data_loader)}") logger.info(f" Instantaneous batch size per device = {config.train_batch_size}") logger.info(f" Gradient accumulation steps = {config.gradient_accumulation_steps}") logger.info(f" Parallel processes = {accelerator.num_processes}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Total optimization steps = {num_train_steps}") logger.info(f" Total epochs = {num_train_epochs}") global_step = 0 first_epoch = 0 completed_epochs = 0 progress_bar = tqdm( range(global_step, num_train_steps), # Only show the progress bar once on each machine. 
disable=not accelerator.is_local_main_process, ) progress_bar.set_description("Steps") # Keep original embeddings as reference. orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone() def save_checkpoint(num_completed_epochs: int, num_completed_steps: int): accelerator.wait_for_everyone() if accelerator.is_main_process: _save_ti_embeddings( epoch=num_completed_epochs, step=num_completed_steps, text_encoder=text_encoder, placeholder_token_ids=placeholder_token_ids, accelerator=accelerator, logger=logger, checkpoint_tracker=checkpoint_tracker, callbacks=callbacks, ) accelerator.wait_for_everyone() def validate(num_completed_epochs: int, num_completed_steps: int): accelerator.wait_for_everyone() if accelerator.is_main_process: generate_validation_images_sd( epoch=num_completed_epochs, step=num_completed_steps, out_dir=out_dir, accelerator=accelerator, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, noise_scheduler=noise_scheduler, unet=unet, config=config, logger=logger, callbacks=callbacks, ) accelerator.wait_for_everyone() for epoch in range(first_epoch, num_train_epochs): text_encoder.train() train_loss = 0.0 for data_batch_idx, data_batch in enumerate(data_loader): with accelerator.accumulate(text_encoder): loss = train_forward( config=config, data_batch=data_batch, vae=vae, noise_scheduler=noise_scheduler, tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, weight_dtype=weight_dtype, use_masks=config.use_masks, min_snr_gamma=config.min_snr_gamma, ) # Gather the losses across all processes for logging (if we use distributed training). # TODO(ryand): Test that this works properly with distributed training. avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean() train_loss += avg_loss.item() / config.gradient_accumulation_steps accelerator.backward(loss) if accelerator.sync_gradients and config.max_grad_norm is not None: # TODO(ryand): I copied this from another pipeline. 
Should probably just clip the trainable params. params_to_clip = text_encoder.parameters() accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) # Make sure we don't update any embedding weights besides the newly-added token(s). # TODO(ryand): Should we only do this if accelerator.sync_gradients? restore_original_embeddings( tokenizer=tokenizer, placeholder_token_ids=placeholder_token_ids, accelerator=accelerator, text_encoder=text_encoder, orig_embeds_params=orig_embeds_params, ) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1 log = {"train_loss": train_loss, "lr": lr_scheduler.get_last_lr()[0]} if config.optimizer.optimizer_type == "Prodigy": # TODO(ryand): Test Prodigy logging. log["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] accelerator.log(log, step=global_step) train_loss = 0.0 # global_step represents the *number of completed steps* at this point. if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0: save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step) if ( config.validate_every_n_steps is not None and global_step % config.validate_every_n_steps == 0 and len(config.validation_prompts) > 0 ): validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step) logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) if global_step >= num_train_steps: break # Save a checkpoint every n epochs. if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0: save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step) # Generate validation images every n epochs. 
if ( config.validate_every_n_epochs is not None and completed_epochs % config.validate_every_n_epochs == 0 and len(config.validation_prompts) > 0 ): validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step) accelerator.end_training() ================================================ FILE: src/invoke_training/pipelines/stable_diffusion_xl/__init__.py ================================================ ================================================ FILE: src/invoke_training/pipelines/stable_diffusion_xl/finetune/__init__.py ================================================ ================================================ FILE: src/invoke_training/pipelines/stable_diffusion_xl/finetune/config.py ================================================ from typing import Annotated, Literal, Union from pydantic import Field, model_validator from invoke_training.config.base_pipeline_config import BasePipelineConfig from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig class SdxlFinetuneConfig(BasePipelineConfig): type: Literal["SDXL_FINETUNE"] = "SDXL_FINETUNE" model: str = "stabilityai/stable-diffusion-xl-base-1.0" """Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint file. (E.g. 'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/JuggernautXL.safetensors', etc. ) """ hf_variant: str | None = "fp16" """The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name. """ save_checkpoint_format: Literal["full_diffusers", "trained_only_diffusers"] = "trained_only_diffusers" """The save format for the checkpoints. Options: - `full_diffusers`: Save the full model in diffusers format (including models that weren't finetuned). 
If you want a single output artifact that can be used for generation, then this is the recommended option. - `trained_only_diffusers`: Save only the models that were finetuned in diffusers format. For example, if only the UNet model was trained, then only the UNet model will be saved. This option will significantly reduce the disk space consumed by the saved checkpoints. If you plan to extract a LoRA from the fine-tuned model, then this is the recommended option. """ save_dtype: Literal["float32", "float16", "bfloat16"] = "float16" """The dtype to use when saving the model. """ optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig() lr_scheduler: Literal[ "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" ] = "constant" lr_warmup_steps: int = 0 """The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup. See lr_scheduler. """ min_snr_gamma: float | None = 5.0 """Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy improves the speed of training convergence by adjusting the weight of each sample. `min_snr_gamma` acts like an an upper bound on the weight of samples with low noise levels. If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`. """ cache_text_encoder_outputs: bool = False """If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the text encoders in VRAM), and speeds up training (don't have to run the text encoders for each training example). This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being applied. 
""" cache_vae_outputs: bool = False """If True, the VAE will be applied to all of the images in the dataset before starting training and the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False). """ enable_cpu_offload_during_validation: bool = False """If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation images. This reduces VRAM requirements at the cost of slower generation of validation images. """ gradient_accumulation_steps: int = 1 """The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face Accelerate. This is an alternative to increasing the batch size when training with limited VRAM. """ weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16" """All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and result in faster training, but are more prone to issues with numerical stability. Recommendations: - `"float32"`: Use this mode if you have plenty of VRAM available. - `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16. - `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16. See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig.mixed_precision]. """ # noqa: E501 mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no" """The mixed precision mode to use. If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and trainable parameters are kept in float32 precision to avoid issues with numerical stability. 
This value is passed to Hugging Face Accelerate. See [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision) for more details. """ # noqa: E501 xformers: bool = False """If true, use xformers for more efficient attention blocks. """ gradient_checkpointing: bool = False """Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling gradient checkpointing slows down training by ~20%. """ max_checkpoints: int | None = None """The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately. """ prediction_type: Literal["epsilon", "v_prediction"] | None = None """The prediction_type that will be used for training. Choose between 'epsilon' or 'v_prediction' or leave 'None'. If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used. """ max_grad_norm: float | None = None """Max gradient norm for clipping. Set to null or 0 for no clipping. """ validation_prompts: list[str] = [] """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress. See also 'validate_every_n_epochs'. """ negative_validation_prompts: list[str] | None = None """A list of negative prompts that will be applied when generating validation images. If set, this list should have the same length as 'validation_prompts'. """ num_validation_images_per_prompt: int = 4 """The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can become quite slow if this number is too large. """ train_batch_size: int = 4 """The training batch size. """ use_masks: bool = False """If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this feature to be used. 
""" data_loader: Annotated[ Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig], Field(discriminator="type") ] vae_model: str | None = None """The name of the Hugging Face Hub VAE model to train against. This will override the VAE bundled with the base model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL shipped with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version. """ @model_validator(mode="after") def check_validation_prompts(self): if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len( self.validation_prompts ): raise ValueError( f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of " f"negative_validation_prompts ({len(self.negative_validation_prompts)})." ) return self ================================================ FILE: src/invoke_training/pipelines/stable_diffusion_xl/finetune/train.py ================================================ import itertools import json import logging import math import os import tempfile import time from typing import Literal import peft import torch import torch.utils.data from accelerate.utils import set_seed from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel from diffusers.optimization import get_scheduler from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from invoke_training._shared.accelerator.accelerator_utils import ( get_dtype_from_str, initialize_accelerator, initialize_logging, ) from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets from invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer from invoke_training._shared.stable_diffusion.checkpoint_utils import ( save_sdxl_diffusers_checkpoint, 
save_sdxl_diffusers_unet_checkpoint, ) from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sdxl from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sdxl from invoke_training._shared.utils.import_xformers import import_xformers from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint from invoke_training.pipelines.stable_diffusion.lora.train import cache_vae_outputs from invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig from invoke_training.pipelines.stable_diffusion_xl.lora.train import ( _build_data_loader, cache_text_encoder_outputs, train_forward, ) def _save_sdxl_checkpoint( epoch: int, step: int, save_checkpoint_format: Literal["full_diffusers", "trained_only_diffusers"], vae: AutoencoderKL, text_encoder_1: CLIPTextModel, text_encoder_2: CLIPTextModel, tokenizer_1: CLIPTokenizer, tokenizer_2: CLIPTokenizer, noise_scheduler: DDPMScheduler, unet: UNet2DConditionModel, save_dtype: torch.dtype, logger: logging.Logger, checkpoint_tracker: CheckpointTracker, callbacks: list[PipelineCallbacks] | None, ): # Prune checkpoints and get new checkpoint path. 
num_pruned = checkpoint_tracker.prune(1) if num_pruned > 0: logger.info(f"Pruned {num_pruned} checkpoint(s).") save_path = checkpoint_tracker.get_path(epoch=epoch, step=step) if save_checkpoint_format == "trained_only_diffusers": model_type = ModelType.SDXL_UNET_DIFFUSERS save_sdxl_diffusers_unet_checkpoint(checkpoint_path=save_path, unet=unet, save_dtype=save_dtype) elif save_checkpoint_format == "full_diffusers": model_type = ModelType.SDXL_FULL_DIFFUSERS save_sdxl_diffusers_checkpoint( checkpoint_path=save_path, vae=vae, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, noise_scheduler=noise_scheduler, unet=unet, save_dtype=save_dtype, ) else: raise ValueError(f"Invalid save_checkpoint_format: '{save_checkpoint_format}'.") if callbacks is not None: for cb in callbacks: cb.on_save_checkpoint( TrainingCheckpoint( models=[ModelCheckpoint(file_path=save_path, model_type=model_type)], epoch=epoch, step=step, ) ) def train(config: SdxlFinetuneConfig, callbacks: list[PipelineCallbacks] | None = None): # noqa: C901 # Give a clear error message if an unsupported base model was chosen. # TODO(ryan): Update this check to work with single-file SD checkpoints. # check_base_model_version( # {BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE}, # config.model, # local_files_only=False, # ) # Create a timestamped directory for all outputs. out_dir = os.path.join(config.base_output_dir, f"{time.time()}") ckpt_dir = os.path.join(out_dir, "checkpoints") os.makedirs(ckpt_dir) accelerator = initialize_accelerator( out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to ) logger = initialize_logging(os.path.basename(__file__), accelerator) # Set the accelerate seed. if config.seed is not None: set_seed(config.seed) # Log the accelerator configuration from every process to help with debugging. 
logger.info(accelerator.state, main_process_only=False) logger.info("Starting Training.") logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}") logger.info(f"Output dir: '{out_dir}'") # Write the configuration to disk. with open(os.path.join(out_dir, "config.json"), "w") as f: json.dump(config.dict(), f, indent=2, default=str) weight_dtype = get_dtype_from_str(config.weight_dtype) logger.info("Loading models.") tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl( logger=logger, model_name_or_path=config.model, hf_variant=config.hf_variant, vae_model=config.vae_model, base_embeddings=None, dtype=weight_dtype, ) if config.xformers: import_xformers() # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers # will fail because Q, K, V have different dtypes. unet.enable_xformers_memory_efficient_attention() vae.enable_xformers_memory_efficient_attention() # Prepare text encoder output cache. text_encoder_output_cache_dir_name = None if config.cache_text_encoder_outputs: # TODO(ryand): Think about how to better check if it is safe to cache the text encoder outputs. Currently, there # are a number of configurations that would cause variation in the text encoder outputs and should not be used # with caching. # We use a temporary directory for the cache. The directory will automatically be cleaned up when # tmp_text_encoder_output_cache_dir is destroyed. tmp_text_encoder_output_cache_dir = tempfile.TemporaryDirectory() text_encoder_output_cache_dir_name = tmp_text_encoder_output_cache_dir.name if accelerator.is_local_main_process: # Only the main process should populate the cache. 
logger.info(f"Generating text encoder output cache ('{text_encoder_output_cache_dir_name}').") text_encoder_1.to(accelerator.device, dtype=weight_dtype) text_encoder_2.to(accelerator.device, dtype=weight_dtype) # TODO(ryan): Move cache_text_encoder_outputs to a shared location so that it is not imported from another # pipeline. cache_text_encoder_outputs( text_encoder_output_cache_dir_name, config, tokenizer_1, tokenizer_2, text_encoder_1, text_encoder_2 ) # Move the text_encoders back to the CPU, because they are not needed for training. text_encoder_1.to("cpu") text_encoder_2.to("cpu") accelerator.wait_for_everyone() else: text_encoder_1.to(accelerator.device, dtype=weight_dtype) text_encoder_2.to(accelerator.device, dtype=weight_dtype) # Prepare VAE output cache. vae_output_cache_dir_name = None if config.cache_vae_outputs: if config.data_loader.random_flip: raise ValueError("'cache_vae_outputs' cannot be True if 'random_flip' is True.") if not config.data_loader.center_crop: raise ValueError("'cache_vae_outputs' cannot be True if 'center_crop' is False.") # We use a temporary directory for the cache. The directory will automatically be cleaned up when # tmp_vae_output_cache_dir is destroyed. tmp_vae_output_cache_dir = tempfile.TemporaryDirectory() vae_output_cache_dir_name = tmp_vae_output_cache_dir.name if accelerator.is_local_main_process: # Only the main process should to populate the cache. logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').") vae.to(accelerator.device, dtype=weight_dtype) # TODO(ryan): Move cache_text_encoder_outputs to a shared location so that it is not imported from another # pipeline. data_loader = _build_data_loader( data_loader_config=config.data_loader, batch_size=config.train_batch_size, use_masks=config.use_masks, shuffle=False, sequential_batching=True, ) cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae) # Move the VAE back to the CPU, because it is not needed for training. 
vae.to("cpu") accelerator.wait_for_everyone() else: vae.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) # Make UNet trainable. unet.requires_grad_(True) unet.train() all_trainable_models = [unet] # If mixed_precision is enabled, cast all trainable params to float32. if config.mixed_precision != "no": for trainable_model in all_trainable_models: for param in trainable_model.parameters(): if param.requires_grad: param.data = param.to(torch.float32) if config.gradient_checkpointing: # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained. unet.enable_gradient_checkpointing() # unet must be in train() mode for gradient checkpointing to take effect. # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does # not change its forward behavior. unet.train() optimizer = initialize_optimizer(config.optimizer, unet.parameters()) data_loader = _build_data_loader( data_loader_config=config.data_loader, batch_size=config.train_batch_size, use_masks=config.use_masks, text_encoder_output_cache_dir=text_encoder_output_cache_dir_name, vae_output_cache_dir=vae_output_cache_dir_name, ) log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler) assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1 assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1 assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1 # A "step" represents a single weight update operation (i.e. takes into account gradient accumulation steps). # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached. 
num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps) num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch) # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82), # so the scaling here simply reverses that behaviour. lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler( config.lr_scheduler, optimizer=optimizer, num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes, num_training_steps=num_train_steps * accelerator.num_processes, ) prepared_result: tuple[ UNet2DConditionModel, peft.PeftModel | CLIPTextModel, peft.PeftModel | CLIPTextModel, torch.optim.Optimizer, torch.utils.data.DataLoader, torch.optim.lr_scheduler.LRScheduler, ] = accelerator.prepare( unet, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler, # Disable automatic device placement for text_encoder if the text encoder outputs were cached. device_placement=[ True, not config.cache_text_encoder_outputs, not config.cache_text_encoder_outputs, True, True, True, ], ) unet, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result if accelerator.is_main_process: accelerator.init_trackers("finetune") # Tensorboard uses markdown formatting, so we wrap the config json in a code block. 
accelerator.log({"configuration": f"```json\n{json.dumps(config.dict(), indent=2, default=str)}\n```\n"}) checkpoint_tracker = CheckpointTracker( base_dir=ckpt_dir, prefix="checkpoint", max_checkpoints=config.max_checkpoints ) # Train! total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num batches = {len(data_loader)}") logger.info(f" Instantaneous batch size per device = {config.train_batch_size}") logger.info(f" Gradient accumulation steps = {config.gradient_accumulation_steps}") logger.info(f" Parallel processes = {accelerator.num_processes}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Total optimization steps = {num_train_steps}") logger.info(f" Total epochs = {num_train_epochs}") global_step = 0 first_epoch = 0 completed_epochs = 0 progress_bar = tqdm( range(global_step, num_train_steps), # Only show the progress bar once on each machine. 
disable=not accelerator.is_local_main_process, ) progress_bar.set_description("Steps") def save_checkpoint(num_completed_epochs: int, num_completed_steps: int): accelerator.wait_for_everyone() if accelerator.is_main_process: _save_sdxl_checkpoint( epoch=num_completed_epochs, step=num_completed_steps, save_checkpoint_format=config.save_checkpoint_format, vae=vae, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, noise_scheduler=noise_scheduler, unet=unet, save_dtype=get_dtype_from_str(config.save_dtype), logger=logger, checkpoint_tracker=checkpoint_tracker, callbacks=callbacks, ) accelerator.wait_for_everyone() def validate(num_completed_epochs: int, num_completed_steps: int): accelerator.wait_for_everyone() if accelerator.is_main_process: generate_validation_images_sdxl( epoch=num_completed_epochs, step=num_completed_steps, out_dir=out_dir, accelerator=accelerator, vae=vae, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, noise_scheduler=noise_scheduler, unet=unet, config=config, logger=logger, callbacks=callbacks, ) accelerator.wait_for_everyone() for epoch in range(first_epoch, num_train_epochs): train_loss = 0.0 for data_batch_idx, data_batch in enumerate(data_loader): with accelerator.accumulate(unet, text_encoder_1, text_encoder_2): loss = train_forward( accelerator=accelerator, data_batch=data_batch, vae=vae, noise_scheduler=noise_scheduler, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, unet=unet, weight_dtype=weight_dtype, resolution=config.data_loader.resolution, use_masks=config.use_masks, prediction_type=config.prediction_type, min_snr_gamma=config.min_snr_gamma, ) # Gather the losses across all processes for logging (if we use distributed training). # TODO(ryand): Test that this works properly with distributed training. 
avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean() train_loss += avg_loss.item() / config.gradient_accumulation_steps # Backpropagate. accelerator.backward(loss) if accelerator.sync_gradients and config.max_grad_norm is not None: params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models]) accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) # Checks if the accelerator has performed an optimization step behind the scenes. if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1 log = {"train_loss": train_loss} lrs = lr_scheduler.get_last_lr() # When training the UNet, it will always be the first parameter group. log["lr/unet"] = float(lrs[0]) if config.optimizer.optimizer_type == "Prodigy": log["lr/d*lr/unet"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] accelerator.log(log, step=global_step) train_loss = 0.0 # global_step represents the *number of completed steps* at this point. if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0: save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step) if ( config.validate_every_n_steps is not None and global_step % config.validate_every_n_steps == 0 and len(config.validation_prompts) > 0 ): validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step) logs = { "step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], } progress_bar.set_postfix(**logs) if global_step >= num_train_steps: break # Save a checkpoint every n epochs. if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0: save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step) # Generate validation images every n epochs. 
if ( config.validate_every_n_epochs is not None and completed_epochs % config.validate_every_n_epochs == 0 and len(config.validation_prompts) > 0 ): validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step) accelerator.end_training()

================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora/__init__.py
================================================

================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora/config.py
================================================
from typing import Annotated, Literal, Union

from pydantic import Field, model_validator

from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (
    TEXT_ENCODER_TARGET_MODULES,
    UNET_TARGET_MODULES,
)
from invoke_training.config.base_pipeline_config import BasePipelineConfig
from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig
from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig


class SdxlLoraConfig(BasePipelineConfig):
    """Configuration for the SDXL LoRA training pipeline (`type: "SDXL_LORA"`).

    Each field below is documented with its own docstring. The only cross-field validation performed by this class is
    `check_validation_prompts`.
    """

    type: Literal["SDXL_LORA"] = "SDXL_LORA"

    model: str = "stabilityai/stable-diffusion-xl-base-1.0"
    """Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint
    file. (E.g. 'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/JuggernautXL.safetensors', etc. )
    """

    hf_variant: str | None = "fp16"
    """The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name.
    """

    # Note: Pydantic handles mutable default values well:
    # https://docs.pydantic.dev/latest/concepts/models/#fields-with-non-hashable-default-values
    base_embeddings: dict[str, str] = {}
    """A mapping of embedding tokens to trained embedding file paths. These embeddings will be applied to the base model
    before training.

    Example:
    ```
    base_embeddings = {
        "bruce_the_gnome": "/path/to/bruce_the_gnome.safetensors",
    }
    ```

    Consider also adding the embedding tokens to the `data_loader.caption_prefix` if they are not already present in the
    dataset captions.

    Note that the embeddings themselves are not fine-tuned further, but they will impact the LoRA model training if they
    are referenced in the dataset captions. The list of embeddings provided here should be the same list used at
    generation time with the resultant LoRA model.
    """

    lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya"
    """The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`."""

    train_unet: bool = True
    """Whether to add LoRA layers to the UNet model and train it.
    """

    train_text_encoder: bool = True
    """Whether to add LoRA layers to the text encoder and train it.
    """

    optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig()

    text_encoder_learning_rate: float | None = None
    """The learning rate to use for the text encoder model. If set, this overrides the optimizer's default learning
    rate. Set to null or 0 to use the optimizer's default learning rate.
    """

    unet_learning_rate: float | None = None
    """The learning rate to use for the UNet model. If set, this overrides the optimizer's default learning rate.
    Set to null or 0 to use the optimizer's default learning rate.
    """

    lr_scheduler: Literal[
        "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"
    ] = "constant"

    lr_warmup_steps: int = 0
    """The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup.
    See lr_scheduler.
    """

    min_snr_gamma: float | None = 5.0
    """Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy
    improves the speed of training convergence by adjusting the weight of each sample.

    `min_snr_gamma` acts like an upper bound on the weight of samples with low noise levels.

    If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`.
    """

    lora_rank_dim: int = 4
    """The rank dimension to use for the LoRA layers. Increasing the rank dimension increases the model's expressivity,
    but also increases the size of the generated LoRA model.
    """

    # The default list of target modules is based on
    # https://github.com/huggingface/peft/blob/8665e2b5719faa4e4b91749ddec09442927b53e0/examples/stable_diffusion/train_dreambooth.py#L49C1-L65C87
    unet_lora_target_modules: list[str] = UNET_TARGET_MODULES
    """The list of target modules to apply LoRA layers to in the UNet model. The default list will produce a highly
    expressive LoRA model.

    For a smaller and less expressive LoRA model, the following list is recommended:
    ```python
    unet_lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
    ```

    The list of target modules is passed to Hugging Face's PEFT library. See
    [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for
    details.
    """

    text_encoder_lora_target_modules: list[str] = TEXT_ENCODER_TARGET_MODULES
    """The list of target modules to apply LoRA layers to in the text encoder models. The default list will produce a
    highly expressive LoRA model.

    For a smaller and less expressive LoRA model, the following list is recommended:
    ```python
    text_encoder_lora_target_modules = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"]
    ```

    The list of target modules is passed to Hugging Face's PEFT library. See
    [the docs](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules) for
    details.
    """

    # NOTE(review): the constraints described in the two cache_* docstrings below are not enforced by a validator in
    # this class (check_validation_prompts is the only validator here); presumably the training code checks them —
    # confirm.
    cache_text_encoder_outputs: bool = False
    """If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and
    the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the
    text encoders in VRAM), and speeds up training (don't have to run the text encoders for each training example).
    This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being
    applied.
    """

    cache_vae_outputs: bool = False
    """If True, the VAE will be applied to all of the images in the dataset before starting training and the results
    will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM),
    and speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all
    non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False).
    """

    enable_cpu_offload_during_validation: bool = False
    """If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation
    images. This reduces VRAM requirements at the cost of slower generation of validation images.
    """

    gradient_accumulation_steps: int = 1
    """The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face
    Accelerate. This is an alternative to increasing the batch size when training with limited VRAM.
    """

    weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16"
    """All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and
    result in faster training, but are more prone to issues with numerical stability.

    Recommendations:

    - `"float32"`: Use this mode if you have plenty of VRAM available.
    - `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16.
    - `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16.

    See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion_xl.lora.config.SdxlLoraConfig.mixed_precision].
    """  # noqa: E501

    mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no"
    """The mixed precision mode to use.

    If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and
    trainable parameters are kept in float32 precision to avoid issues with numerical stability.

    This value is passed to Hugging Face Accelerate. See
    [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision)
    for more details.
    """  # noqa: E501

    xformers: bool = False
    """If true, use xformers for more efficient attention blocks.
    """

    gradient_checkpointing: bool = False
    """Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling
    gradient checkpointing slows down training by ~20%.
    """

    max_checkpoints: int | None = None
    """The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this
    limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately.
    """

    prediction_type: Literal["epsilon", "v_prediction"] | None = None
    """The prediction_type that will be used for training. Choose between 'epsilon' or 'v_prediction' or leave 'None'.
    If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used.
    """

    # NOTE(review): confirm that the LoRA training loop treats 0 as "no clipping". A plain `is not None` check (as used
    # by the SDXL finetune loop) would instead clip gradients to a norm of 0.
    max_grad_norm: float | None = None
    """Max gradient norm for clipping. Set to null or 0 for no clipping.
    """

    validation_prompts: list[str] = []
    """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress.
    See also 'validate_every_n_epochs'.
    """

    negative_validation_prompts: list[str] | None = None
    """A list of negative prompts that will be applied when generating validation images. If set, this list should have
    the same length as 'validation_prompts'.
    """

    num_validation_images_per_prompt: int = 4
    """The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can
    become quite slow if this number is too large.
    """

    train_batch_size: int = 4
    """The training batch size.
    """

    use_masks: bool = False
    """If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this
    feature to be used.
    """

    # Required field (no default). Pydantic picks the concrete data loader config class based on its `type` field.
    data_loader: Annotated[
        Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig], Field(discriminator="type")
    ]

    vae_model: str | None = None
    """The name of the Hugging Face Hub VAE model to train against. This will override the VAE bundled with the base
    model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL shipped
    with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.
    """

    @model_validator(mode="after")
    def check_validation_prompts(self):
        """Ensure that `negative_validation_prompts`, when provided, has one entry per validation prompt.

        Raises:
            ValueError: If the two lists have different lengths.
        """
        if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len(
            self.validation_prompts
        ):
            raise ValueError(
                f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of "
                f"negative_validation_prompts ({len(self.negative_validation_prompts)})."
            )
        return self

================================================
FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora/train.py
================================================
import itertools
import json
import logging
import math
import os
import tempfile
import time
from pathlib import Path
from typing import Literal, Optional, Union

import peft
import torch
import torch.utils.data
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import CLIPPreTrainedModel, CLIPTextModel, PreTrainedTokenizer

from invoke_training._shared.accelerator.accelerator_utils import (
    get_dtype_from_str,
    initialize_accelerator,
    initialize_logging,
)
from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker
from invoke_training._shared.data.data_loaders.dreambooth_sd_dataloader import build_dreambooth_sd_dataloader
from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import build_image_caption_sd_dataloader
from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets
from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache
from invoke_training._shared.data.utils.resolution import Resolution
from invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer
from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import (
    save_sdxl_kohya_checkpoint,
    save_sdxl_peft_checkpoint,
)
from invoke_training._shared.stable_diffusion.min_snr_weighting import compute_snr
from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sdxl
from invoke_training._shared.stable_diffusion.tokenize_captions import tokenize_captions
from invoke_training._shared.stable_diffusion.validation import
generate_validation_images_sdxl from invoke_training._shared.utils.import_xformers import import_xformers from invoke_training.config.data.data_loader_config import DreamboothSDDataLoaderConfig, ImageCaptionSDDataLoaderConfig from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint from invoke_training.pipelines.stable_diffusion.lora.train import cache_vae_outputs from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig def _save_sdxl_lora_checkpoint( epoch: int, step: int, unet: peft.PeftModel | None, text_encoder_1: peft.PeftModel | None, text_encoder_2: peft.PeftModel | None, logger: logging.Logger, checkpoint_tracker: CheckpointTracker, lora_checkpoint_format: Literal["invoke_peft", "kohya"], callbacks: list[PipelineCallbacks] | None, ): # Prune checkpoints and get new checkpoint path. num_pruned = checkpoint_tracker.prune(1) if num_pruned > 0: logger.info(f"Pruned {num_pruned} checkpoint(s).") save_path = checkpoint_tracker.get_path(epoch=epoch, step=step) if lora_checkpoint_format == "invoke_peft": model_type = ModelType.SD1_LORA_PEFT save_sdxl_peft_checkpoint( Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2 ) elif lora_checkpoint_format == "kohya": model_type = ModelType.SD1_LORA_KOHYA save_sdxl_kohya_checkpoint( Path(save_path), unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2 ) else: raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.") if callbacks is not None: for cb in callbacks: cb.on_save_checkpoint( TrainingCheckpoint( models=[ModelCheckpoint(file_path=save_path, model_type=model_type)], epoch=epoch, step=step ) ) def _build_data_loader( data_loader_config: Union[ImageCaptionSDDataLoaderConfig, DreamboothSDDataLoaderConfig], batch_size: int, use_masks: bool = False, text_encoder_output_cache_dir: Optional[str] = None, vae_output_cache_dir: Optional[str] = None, shuffle: 
bool = True, sequential_batching: bool = False, ) -> DataLoader: if data_loader_config.type == "IMAGE_CAPTION_SD_DATA_LOADER": return build_image_caption_sd_dataloader( config=data_loader_config, batch_size=batch_size, use_masks=use_masks, text_encoder_output_cache_dir=text_encoder_output_cache_dir, text_encoder_cache_field_to_output_field={ "prompt_embeds": "prompt_embeds", "pooled_prompt_embeds": "pooled_prompt_embeds", }, vae_output_cache_dir=vae_output_cache_dir, shuffle=shuffle, ) elif data_loader_config.type == "DREAMBOOTH_SD_DATA_LOADER": if use_masks: raise ValueError("Masks are not yet supported for DreamBooth data loaders.") return build_dreambooth_sd_dataloader( config=data_loader_config, batch_size=batch_size, text_encoder_output_cache_dir=text_encoder_output_cache_dir, text_encoder_cache_field_to_output_field={ "prompt_embeds": "prompt_embeds", "pooled_prompt_embeds": "pooled_prompt_embeds", }, vae_output_cache_dir=vae_output_cache_dir, shuffle=shuffle, sequential_batching=sequential_batching, ) else: raise ValueError(f"Unsupported data loader config type: '{data_loader_config.type}'.") # encode_prompt was adapted from: # https://github.com/huggingface/diffusers/blob/7b07f9812a58bfa96c06ed8ffe9e6b584286e2fd/examples/text_to_image/train_text_to_image_lora_sdxl.py#L470-L496 def _encode_prompt(text_encoders: list[CLIPPreTrainedModel], prompt_token_ids_list: list[torch.Tensor]): prompt_embeds_list = [] for i, text_encoder in enumerate(text_encoders): text_input_ids = prompt_token_ids_list[i] prompt_embeds = text_encoder( text_input_ids.to(text_encoder.device), output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder. # TODO(ryand): Document this logic more clearly. 
pooled_prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.hidden_states[-2] bs_embed, seq_len, _ = prompt_embeds.shape prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) return prompt_embeds, pooled_prompt_embeds # TODO(ryand): Cache VAE outputs and text encoder outputs at the same time in a single pass over the dataset. def cache_text_encoder_outputs( cache_dir: str, config: SdxlLoraConfig, tokenizer_1: PreTrainedTokenizer, tokenizer_2: PreTrainedTokenizer, text_encoder_1: CLIPPreTrainedModel, text_encoder_2: CLIPPreTrainedModel, ): """Run the text encoder on all captions in the dataset and cache the results to disk. Args: cache_dir (str): The directory where the results will be cached. config (FinetuneLoRAConfig): Training config. tokenizer_1 (PreTrainedTokenizer): tokenizer_2 (PreTrainedTokenizer): text_encoder_1 (CLIPPreTrainedModel): text_encoder_2 (CLIPPreTrainedModel): """ data_loader = _build_data_loader( data_loader_config=config.data_loader, batch_size=config.train_batch_size, shuffle=False, sequential_batching=True, ) cache = TensorDiskCache(cache_dir) for data_batch in tqdm(data_loader): caption_token_ids_1 = tokenize_captions(tokenizer_1, data_batch["caption"]) caption_token_ids_2 = tokenize_captions(tokenizer_2, data_batch["caption"]) prompt_embeds, pooled_prompt_embeds = _encode_prompt( [text_encoder_1, text_encoder_2], [caption_token_ids_1, caption_token_ids_2] ) # Split batch before caching. 
for i in range(len(data_batch["id"])): embeds = { "prompt_embeds": prompt_embeds[i], "pooled_prompt_embeds": pooled_prompt_embeds[i], } cache.save(data_batch["id"][i], embeds) def train_forward( # noqa: C901 accelerator: Accelerator, data_batch: dict, vae: AutoencoderKL, noise_scheduler: DDPMScheduler, tokenizer_1: PreTrainedTokenizer, tokenizer_2: PreTrainedTokenizer, text_encoder_1: CLIPPreTrainedModel, text_encoder_2: CLIPPreTrainedModel, unet: UNet2DConditionModel, weight_dtype: torch.dtype, resolution: int | tuple[int, int], use_masks: bool = False, prediction_type=None, min_snr_gamma: float | None = None, ): """Run the forward training pass for a single data_batch. Returns: torch.Tensor: Loss """ # Convert images to latent space. # The VAE output may have been cached and included in the data_batch. If not, we calculate it here. latents = data_batch.get("vae_output", None) if latents is None: latents = vae.encode(data_batch["image"].to(dtype=weight_dtype)).latent_dist.sample() latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents. noise = torch.randn_like(latents) batch_size = latents.shape[0] # Sample a random timestep for each image. timesteps = torch.randint( 0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=latents.device, ) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep (this is the forward diffusion # process). noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # compute_time_ids was copied from: # https://github.com/huggingface/diffusers/blob/7b07f9812a58bfa96c06ed8ffe9e6b584286e2fd/examples/text_to_image/train_text_to_image_lora_sdxl.py#L1033-L1039 # "time_ids" may seem like a weird naming choice. The name comes from the diffusers SDXL implementation. Presumably, # it is a result of the fact that the original size and crop values get concatenated with the time embeddings. 
def compute_time_ids(original_size, crops_coords_top_left): # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids target_size = Resolution.parse(resolution).to_tuple() add_time_ids = list(original_size + crops_coords_top_left + target_size) add_time_ids = torch.tensor([add_time_ids]) add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) return add_time_ids add_time_ids = torch.cat( [compute_time_ids(s, c) for s, c in zip(data_batch["original_size_hw"], data_batch["crop_top_left_yx"])] ) unet_conditions = {"time_ids": add_time_ids} # Get the text embedding for conditioning. # The text encoder output may have been cached and included in the data_batch. If not, we calculate it here. if "prompt_embeds" in data_batch: prompt_embeds = data_batch["prompt_embeds"] pooled_prompt_embeds = data_batch["pooled_prompt_embeds"] else: caption_token_ids_1 = tokenize_captions(tokenizer_1, data_batch["caption"]) caption_token_ids_2 = tokenize_captions(tokenizer_2, data_batch["caption"]) prompt_embeds, pooled_prompt_embeds = _encode_prompt( [text_encoder_1, text_encoder_2], [caption_token_ids_1, caption_token_ids_2] ) prompt_embeds = prompt_embeds.to(dtype=weight_dtype) pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=weight_dtype) unet_conditions["text_embeds"] = pooled_prompt_embeds # Get the target for loss depending on the prediction type. if prediction_type is not None: # Set the prediction_type of scheduler if it's defined in config. noise_scheduler.register_to_config(prediction_type=prediction_type) if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") # Predict the noise residual. 
model_pred = unet(noisy_latents, timesteps, prompt_embeds, added_cond_kwargs=unet_conditions).sample min_snr_weights = None if min_snr_gamma is not None: # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) # Note: We divide by snr here per Section 4.2 of the paper, since we are predicting the noise instead of x_0. # w_t = min(1, SNR(t)) / SNR(t) min_snr_weights = torch.clamp(snr, max=min_snr_gamma) / snr if noise_scheduler.config.prediction_type == "epsilon": pass elif noise_scheduler.config.prediction_type == "v_prediction": # Velocity objective needs to be floored to an SNR weight of one. min_snr_weights = min_snr_weights + 1 else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none") if use_masks: # TODO(ryand): As a future performance optimization, we may want to do this resizing in the dataloader. mask = data_batch["mask"].to(dtype=loss.dtype, device=loss.device) _, _, latent_h, latent_w = loss.shape mask = torch.nn.functional.interpolate(mask, size=(latent_h, latent_w), mode="nearest") loss = loss * mask # Mean-reduce the loss along all dimensions except for the batch dimension. loss = loss.mean(dim=list(range(1, len(loss.shape)))) # Apply min_snr_weights. if min_snr_weights is not None: loss = loss * min_snr_weights # Apply per-example loss weights. if "loss_weight" in data_batch: loss = loss * data_batch["loss_weight"] return loss.mean() def train(config: SdxlLoraConfig, callbacks: list[PipelineCallbacks] | None = None): # noqa: C901 # Give a clear error message if an unsupported base model was chosen. # TODO(ryan): Update this check to work with single-file SD checkpoints. 
# check_base_model_version( # {BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE}, # config.model, # local_files_only=False, # ) # Create a timestamped directory for all outputs. out_dir = os.path.join(config.base_output_dir, f"{time.time()}") ckpt_dir = os.path.join(out_dir, "checkpoints") os.makedirs(ckpt_dir) accelerator = initialize_accelerator( out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to ) logger = initialize_logging(os.path.basename(__file__), accelerator) # Set the accelerate seed. if config.seed is not None: set_seed(config.seed) # Log the accelerator configuration from every process to help with debugging. logger.info(accelerator.state, main_process_only=False) logger.info("Starting Training.") logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}") logger.info(f"Output dir: '{out_dir}'") # Write the configuration to disk. with open(os.path.join(out_dir, "config.json"), "w") as f: json.dump(config.dict(), f, indent=2, default=str) weight_dtype = get_dtype_from_str(config.weight_dtype) logger.info("Loading models.") tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl( logger=logger, model_name_or_path=config.model, hf_variant=config.hf_variant, vae_model=config.vae_model, base_embeddings=config.base_embeddings, dtype=weight_dtype, ) if config.xformers: import_xformers() # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers # will fail because Q, K, V have different dtypes. unet.enable_xformers_memory_efficient_attention() vae.enable_xformers_memory_efficient_attention() # Prepare text encoder output cache. text_encoder_output_cache_dir_name = None if config.cache_text_encoder_outputs: # TODO(ryand): Think about how to better check if it is safe to cache the text encoder outputs. 
Currently, there # are a number of configurations that would cause variation in the text encoder outputs and should not be used # with caching. if config.train_text_encoder: raise ValueError("'cache_text_encoder_outputs' and 'train_text_encoder' cannot both be True.") # We use a temporary directory for the cache. The directory will automatically be cleaned up when # tmp_text_encoder_output_cache_dir is destroyed. tmp_text_encoder_output_cache_dir = tempfile.TemporaryDirectory() text_encoder_output_cache_dir_name = tmp_text_encoder_output_cache_dir.name if accelerator.is_local_main_process: # Only the main process should populate the cache. logger.info(f"Generating text encoder output cache ('{text_encoder_output_cache_dir_name}').") text_encoder_1.to(accelerator.device, dtype=weight_dtype) text_encoder_2.to(accelerator.device, dtype=weight_dtype) cache_text_encoder_outputs( text_encoder_output_cache_dir_name, config, tokenizer_1, tokenizer_2, text_encoder_1, text_encoder_2 ) # Move the text_encoders back to the CPU, because they are not needed for training. text_encoder_1.to("cpu") text_encoder_2.to("cpu") accelerator.wait_for_everyone() else: text_encoder_1.to(accelerator.device, dtype=weight_dtype) text_encoder_2.to(accelerator.device, dtype=weight_dtype) # Prepare VAE output cache. vae_output_cache_dir_name = None if config.cache_vae_outputs: if config.data_loader.random_flip: raise ValueError("'cache_vae_outputs' cannot be True if 'random_flip' is True.") if not config.data_loader.center_crop: raise ValueError("'cache_vae_outputs' cannot be True if 'center_crop' is False.") # We use a temporary directory for the cache. The directory will automatically be cleaned up when # tmp_vae_output_cache_dir is destroyed. tmp_vae_output_cache_dir = tempfile.TemporaryDirectory() vae_output_cache_dir_name = tmp_vae_output_cache_dir.name if accelerator.is_local_main_process: # Only the main process should to populate the cache. 
logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').") vae.to(accelerator.device, dtype=weight_dtype) data_loader = _build_data_loader( data_loader_config=config.data_loader, batch_size=config.train_batch_size, use_masks=config.use_masks, shuffle=False, sequential_batching=True, ) cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae) # Move the VAE back to the CPU, because it is not needed for training. vae.to("cpu") accelerator.wait_for_everyone() else: vae.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) # Add LoRA layers to the models being trained. trainable_param_groups = [] all_trainable_models: list[peft.PeftModel] = [] def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float | None = None) -> peft.PeftModel: peft_model = peft.get_peft_model(model, lora_config) peft_model.print_trainable_parameters() # Populate `trainable_param_groups`, to be passed to the optimizer. param_group = {"params": list(filter(lambda p: p.requires_grad, peft_model.parameters()))} if lr is not None: param_group["lr"] = lr trainable_param_groups.append(param_group) # Populate all_trainable_models. all_trainable_models.append(peft_model) peft_model.train() return peft_model if config.train_unet: unet_lora_config = peft.LoraConfig( r=config.lora_rank_dim, # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred? 
lora_alpha=1.0, target_modules=config.unet_lora_target_modules, ) unet = inject_lora_layers(unet, unet_lora_config, lr=config.unet_learning_rate) if config.train_text_encoder: text_encoder_lora_config = peft.LoraConfig( r=config.lora_rank_dim, lora_alpha=1.0, # init_lora_weights="gaussian", target_modules=config.text_encoder_lora_target_modules, ) text_encoder_1 = inject_lora_layers( text_encoder_1, text_encoder_lora_config, lr=config.text_encoder_learning_rate ) text_encoder_2 = inject_lora_layers( text_encoder_2, text_encoder_lora_config, lr=config.text_encoder_learning_rate ) # If mixed_precision is enabled, cast all trainable params to float32. if config.mixed_precision != "no": for trainable_model in all_trainable_models: for param in trainable_model.parameters(): if param.requires_grad: param.data = param.to(torch.float32) if config.gradient_checkpointing: # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained. unet.enable_gradient_checkpointing() # unet must be in train() mode for gradient checkpointing to take effect. # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does # not change its forward behavior. unet.train() if config.train_text_encoder: for te in [text_encoder_1, text_encoder_2]: te.gradient_checkpointing_enable() # The text encoders must be in train() mode for gradient checkpointing to take effect. This should # already be the case, since we are training the text_encoders, be we do it explicitly to make it clear # that this is required. # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text # encoders in train mode does not change their forward behavior. te.train() # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder # LoRA weights would have 0 gradients, and so would not get trained. 
Note that the set of # trainable_param_groups has already been populated - the embeddings will not be trained. te.text_model.embeddings.requires_grad_(True) optimizer = initialize_optimizer(config.optimizer, trainable_param_groups) data_loader = _build_data_loader( data_loader_config=config.data_loader, batch_size=config.train_batch_size, use_masks=config.use_masks, text_encoder_output_cache_dir=text_encoder_output_cache_dir_name, vae_output_cache_dir=vae_output_cache_dir_name, ) log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler) assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1 assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1 assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1 # A "step" represents a single weight update operation (i.e. takes into account gradient accumulation steps). # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached. num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps) num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch) # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82), # so the scaling here simply reverses that behaviour. 
lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler( config.lr_scheduler, optimizer=optimizer, num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes, num_training_steps=num_train_steps * accelerator.num_processes, ) prepared_result: tuple[ UNet2DConditionModel, peft.PeftModel | CLIPTextModel, peft.PeftModel | CLIPTextModel, torch.optim.Optimizer, torch.utils.data.DataLoader, torch.optim.lr_scheduler.LRScheduler, ] = accelerator.prepare( unet, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler, # Disable automatic device placement for text_encoder if the text encoder outputs were cached. device_placement=[ True, not config.cache_text_encoder_outputs, not config.cache_text_encoder_outputs, True, True, True, ], ) unet, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result if accelerator.is_main_process: accelerator.init_trackers("lora_training") # Tensorboard uses markdown formatting, so we wrap the config json in a code block. accelerator.log({"configuration": f"```json\n{json.dumps(config.dict(), indent=2, default=str)}\n```\n"}) checkpoint_tracker = CheckpointTracker( base_dir=ckpt_dir, prefix="checkpoint", max_checkpoints=config.max_checkpoints, extension=".safetensors" if config.lora_checkpoint_format == "kohya" else None, ) # Train! total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num batches = {len(data_loader)}") logger.info(f" Instantaneous batch size per device = {config.train_batch_size}") logger.info(f" Gradient accumulation steps = {config.gradient_accumulation_steps}") logger.info(f" Parallel processes = {accelerator.num_processes}") logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Total optimization steps = {num_train_steps}") logger.info(f" Total epochs = {num_train_epochs}") global_step = 0 first_epoch = 0 completed_epochs = 0 progress_bar = tqdm( range(global_step, num_train_steps), # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) progress_bar.set_description("Steps") def save_checkpoint(num_completed_epochs: int, num_completed_steps: int): accelerator.wait_for_everyone() if accelerator.is_main_process: _save_sdxl_lora_checkpoint( epoch=num_completed_epochs, step=num_completed_steps, unet=unet if config.train_unet else None, text_encoder_1=text_encoder_1 if config.train_text_encoder else None, text_encoder_2=text_encoder_2 if config.train_text_encoder else None, logger=logger, checkpoint_tracker=checkpoint_tracker, lora_checkpoint_format=config.lora_checkpoint_format, callbacks=callbacks, ) accelerator.wait_for_everyone() def validate(num_completed_epochs: int, num_completed_steps: int): accelerator.wait_for_everyone() if accelerator.is_main_process: generate_validation_images_sdxl( epoch=num_completed_epochs, step=num_completed_steps, out_dir=out_dir, accelerator=accelerator, vae=vae, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, noise_scheduler=noise_scheduler, unet=unet, config=config, logger=logger, callbacks=callbacks, ) accelerator.wait_for_everyone() for epoch in range(first_epoch, num_train_epochs): train_loss = 0.0 for data_batch_idx, data_batch in enumerate(data_loader): with accelerator.accumulate(unet, text_encoder_1, text_encoder_2): loss = train_forward( accelerator=accelerator, data_batch=data_batch, vae=vae, noise_scheduler=noise_scheduler, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, unet=unet, weight_dtype=weight_dtype, resolution=config.data_loader.resolution, 
use_masks=config.use_masks, prediction_type=config.prediction_type, min_snr_gamma=config.min_snr_gamma, ) # Gather the losses across all processes for logging (if we use distributed training). # TODO(ryand): Test that this works properly with distributed training. avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean() train_loss += avg_loss.item() / config.gradient_accumulation_steps # Backpropagate. accelerator.backward(loss) if accelerator.sync_gradients and config.max_grad_norm is not None: params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models]) accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) # Checks if the accelerator has performed an optimization step behind the scenes. if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1 log = {"train_loss": train_loss} lrs = lr_scheduler.get_last_lr() if config.train_unet: # When training the UNet, it will always be the first parameter group. log["lr/unet"] = float(lrs[0]) if config.optimizer.optimizer_type == "Prodigy": log["lr/d*lr/unet"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] if config.train_text_encoder: # When training the text encoder, it will always be the last parameter group. log["lr/text_encoder"] = float(lrs[-1]) if config.optimizer.optimizer_type == "Prodigy": log["lr/d*lr/text_encoder"] = optimizer.param_groups[-1]["d"] * optimizer.param_groups[-1]["lr"] accelerator.log(log, step=global_step) train_loss = 0.0 # global_step represents the *number of completed steps* at this point. 
if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0: save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step) if ( config.validate_every_n_steps is not None and global_step % config.validate_every_n_steps == 0 and len(config.validation_prompts) > 0 ): validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step) logs = { "step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], } progress_bar.set_postfix(**logs) if global_step >= num_train_steps: break # Save a checkpoint every n epochs. if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0: save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step) # Generate validation images every n epochs. if ( config.validate_every_n_epochs is not None and completed_epochs % config.validate_every_n_epochs == 0 and len(config.validation_prompts) > 0 ): validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step) accelerator.end_training() ================================================ FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/__init__.py ================================================ ================================================ FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/config.py ================================================ from typing import Literal from pydantic import model_validator from invoke_training.config.base_pipeline_config import BasePipelineConfig from invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig class SdxlLoraAndTextualInversionConfig(BasePipelineConfig): type: Literal["SDXL_LORA_AND_TEXTUAL_INVERSION"] = "SDXL_LORA_AND_TEXTUAL_INVERSION" model: str = 
"stabilityai/stable-diffusion-xl-base-1.0" """Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint file. (E.g. 'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/JuggernautXL.safetensors', etc. ) """ hf_variant: str | None = "fp16" """The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name. """ lora_checkpoint_format: Literal["invoke_peft", "kohya"] = "kohya" """The format of the LoRA checkpoint to save. Choose between `invoke_peft` or `kohya`.""" # Helpful discussion for understanding how this works at inference time: # https://github.com/huggingface/diffusers/pull/3144#discussion_r1172413509 num_vectors: int = 1 """Note: `num_vectors` can be overridden by `initial_phrase`. The number of textual inversion embedding vectors that will be used to learn the concept. Increasing the `num_vectors` enables the model to learn more complex concepts, but has the following drawbacks: - greater risk of overfitting - increased size of the resulting output file - consumes more of the prompt capacity at inference time Typical values for `num_vectors` are in the range [1, 16]. As a rule of thumb, `num_vectors` can be increased as the size of the dataset increases (without overfitting). """ placeholder_token: str """The special word to associate the learned embeddings with. Choose a unique token that is unlikely to already exist in the tokenizer's vocabulary. """ initializer_token: str | None = None """A vocabulary token to use as an initializer for the placeholder token. It should be a single word that roughly describes the object or style that you're trying to train on. Must map to a single tokenizer token. For example, if you are training on a dataset of images of your pet dog, a good choice would be `dog`. """ initial_phrase: str | None = None """Note: Exactly one of `initializer_token` or `initial_phrase` should be set. 
A phrase that will be used to initialize the placeholder token embedding. The phrase will be tokenized, and the corresponding embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be inferred from the length of the tokenized phrase, so keep the phrase short. The consequences of training a large number of embedding vectors are discussed in the `num_vectors` field documentation. For example, if you are training on a dataset of images of pokemon, you might use `pokemon sketch white background`. """ train_unet: bool = True """Whether to add LoRA layers to the UNet model and train it. """ train_text_encoder: bool = True """Whether to add LoRA layers to the text encoder and train it. """ train_ti: bool = True """Whether to train the textual inversion embeddings.""" ti_train_steps_ratio: float | None = None """The fraction of the total training steps for which the TI embeddings will be trained. For example, if we are training for a total of 5000 steps and `ti_train_steps_ratio=0.5`, then the TI embeddings will be trained for 2500 steps and the will be frozen for the remaining steps. If `None`, then the TI embeddings will be trained for the entire duration of training. """ optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig() text_encoder_learning_rate: float | None = 1e-5 """The learning rate to use for the text encoder model. Set to null or 0 to use the optimizer's default learning rate. """ unet_learning_rate: float | None = 1e-4 """The learning rate to use for the UNet model. Set to null or 0 to use the optimizer's default learning rate. """ textual_inversion_learning_rate: float | None = 1e-3 """The learning rate to use for textual inversion training of the embeddings. Set to null or 0 to use the optimizer's default learning rate. 
""" lr_scheduler: Literal[ "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" ] = "constant" lr_warmup_steps: int = 0 """The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup. See lr_scheduler. """ min_snr_gamma: float | None = 5.0 """Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy improves the speed of training convergence by adjusting the weight of each sample. `min_snr_gamma` acts like an an upper bound on the weight of samples with low noise levels. If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`. """ lora_rank_dim: int = 4 """The rank dimension to use for the LoRA layers. Increasing the rank dimension increases the model's expressivity, but also increases the size of the generated LoRA model. """ cache_text_encoder_outputs: bool = False """If True, the text encoder(s) will be applied to all of the captions in the dataset before starting training and the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the text encoders in VRAM), and speeds up training (don't have to run the text encoders for each training example). This option can only be enabled if `train_text_encoder == False` and there are no caption augmentations being applied. """ cache_vae_outputs: bool = False """If True, the VAE will be applied to all of the images in the dataset before starting training and the results will be cached to disk. This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all non-deterministic image augmentations are disabled (i.e. center_crop=True, random_flip=False). 
""" enable_cpu_offload_during_validation: bool = False """If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation images. This reduces VRAM requirements at the cost of slower generation of validation images. """ gradient_accumulation_steps: int = 1 """The number of gradient steps to accumulate before each weight update. This value is passed to Hugging Face Accelerate. This is an alternative to increasing the batch size when training with limited VRAM. """ weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16" """All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and result in faster training, but are more prone to issues with numerical stability. Recommendations: - `"float32"`: Use this mode if you have plenty of VRAM available. - `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16. - `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16. See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config.SdxlLoraAndTextualInversionConfig.mixed_precision]. """ # noqa: E501 mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no" """The mixed precision mode to use. If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and trainable parameters are kept in float32 precision to avoid issues with numerical stability. This value is passed to Hugging Face Accelerate. See [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision) for more details. """ # noqa: E501 xformers: bool = False """If true, use xformers for more efficient attention blocks. 
""" gradient_checkpointing: bool = False """Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling gradient checkpointing slows down training by ~20%. """ max_checkpoints: int | None = None """The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately. """ prediction_type: Literal["epsilon", "v_prediction"] | None = None """The prediction_type that will be used for training. Choose between 'epsilon' or 'v_prediction' or leave 'None'. If 'None', the prediction type of the scheduler: `noise_scheduler.config.prediction_type` is used. """ max_grad_norm: float | None = None """Max gradient norm for clipping. Set to null or 0 for no clipping. """ validation_prompts: list[str] = [] """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress. """ negative_validation_prompts: list[str] | None = None """A list of negative prompts that will be applied when generating validation images. If set, this list should have the same length as 'validation_prompts'. """ num_validation_images_per_prompt: int = 4 """The number of validation images to generate for each prompt in 'validation_prompts'. Careful, validation can become quite slow if this number is too large. """ train_batch_size: int = 4 """The training batch size. """ use_masks: bool = False """If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this feature to be used. """ data_loader: TextualInversionSDDataLoaderConfig """The data configuration. See [`TextualInversionSDDataLoaderConfig`][invoke_training.config.data.data_loader_config.TextualInversionSDDataLoaderConfig] for details. """ vae_model: str | None = None """The name of the Hugging Face Hub VAE model to train against. 
This will override the VAE bundled with the base model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL shipped with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version. """ @model_validator(mode="after") def check_validation_prompts(self): if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len( self.validation_prompts ): raise ValueError( f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of " f"negative_validation_prompts ({len(self.negative_validation_prompts)})." ) return self ================================================ FILE: src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/train.py ================================================ import itertools import json import logging import math import os import time from pathlib import Path from typing import Literal import peft import torch import torch.utils.data from accelerate import Accelerator from accelerate.utils import set_seed from diffusers import UNet2DConditionModel from diffusers.optimization import get_scheduler from tqdm.auto import tqdm from transformers import CLIPTextModel from invoke_training._shared.accelerator.accelerator_utils import ( get_dtype_from_str, initialize_accelerator, initialize_logging, ) from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker from invoke_training._shared.checkpoints.serialization import save_state_dict from invoke_training._shared.data.data_loaders.textual_inversion_sd_dataloader import ( build_textual_inversion_sd_dataloader, ) from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets from invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import ( TEXT_ENCODER_TARGET_MODULES, 
UNET_TARGET_MODULES, save_sdxl_kohya_checkpoint, save_sdxl_peft_checkpoint, ) from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sdxl from invoke_training._shared.stable_diffusion.textual_inversion import restore_original_embeddings from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sdxl from invoke_training._shared.utils.import_xformers import import_xformers from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint from invoke_training.pipelines.stable_diffusion_xl.lora.train import train_forward from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import ( SdxlLoraAndTextualInversionConfig, ) from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.train import _initialize_placeholder_tokens def _save_sdxl_lora_and_ti_checkpoint( config: SdxlLoraAndTextualInversionConfig, epoch: int, step: int, unet: peft.PeftModel | None, text_encoder_1: peft.PeftModel | None, text_encoder_2: peft.PeftModel | None, placeholder_token_ids_1: list[int], placeholder_token_ids_2: list[int], accelerator: Accelerator, logger: logging.Logger, checkpoint_tracker: CheckpointTracker, lora_checkpoint_format: Literal["invoke_peft", "kohya"], callbacks: list[PipelineCallbacks] | None, ): # Prune checkpoints and get new checkpoint path. 
num_pruned = checkpoint_tracker.prune(1) if num_pruned > 0: logger.info(f"Pruned {num_pruned} checkpoint(s).") save_path = checkpoint_tracker.get_path(epoch=epoch, step=step) training_checkpoint = TrainingCheckpoint(models=[], epoch=epoch, step=step) if lora_checkpoint_format == "invoke_peft": save_sdxl_peft_checkpoint( Path(save_path), unet=unet if config.train_unet else None, text_encoder_1=text_encoder_1 if config.train_text_encoder else None, text_encoder_2=text_encoder_2 if config.train_text_encoder else None, ) training_checkpoint.models.append(ModelCheckpoint(file_path=save_path, model_type=ModelType.SDXL_LORA_PEFT)) elif lora_checkpoint_format == "kohya": save_sdxl_kohya_checkpoint( Path(save_path) / "lora.safetensors", unet=unet if config.train_unet else None, text_encoder_1=text_encoder_1 if config.train_text_encoder else None, text_encoder_2=text_encoder_2 if config.train_text_encoder else None, ) training_checkpoint.models.append(ModelCheckpoint(file_path=save_path, model_type=ModelType.SDXL_LORA_KOHYA)) else: raise ValueError(f"Unsupported lora_checkpoint_format: '{lora_checkpoint_format}'.") if config.train_ti: ti_checkpoint_path = Path(save_path) / "embeddings.safetensors" learned_embeds_1 = ( accelerator.unwrap_model(text_encoder_1) .get_input_embeddings() .weight[min(placeholder_token_ids_1) : max(placeholder_token_ids_1) + 1] ) learned_embeds_2 = ( accelerator.unwrap_model(text_encoder_2) .get_input_embeddings() .weight[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] ) learned_embeds_dict = { "clip_l": learned_embeds_1.detach().cpu().to(dtype=torch.float32), "clip_g": learned_embeds_2.detach().cpu().to(dtype=torch.float32), } save_state_dict(learned_embeds_dict, ti_checkpoint_path) training_checkpoint.models.append( ModelCheckpoint(file_path=ti_checkpoint_path, model_type=ModelType.SDXL_TEXTUAL_INVERSION) ) if callbacks is not None: for cb in callbacks: cb.on_save_checkpoint(training_checkpoint) def train(config: 
SdxlLoraAndTextualInversionConfig, callbacks: list[PipelineCallbacks] | None = None): # noqa: C901 # Give a clear error message if an unsupported base model was chosen. # TODO(ryan): Update this check to work with single-file SD checkpoints. # check_base_model_version( # {BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE}, # config.model, # local_files_only=False, # ) # Create a timestamped directory for all outputs. out_dir = os.path.join(config.base_output_dir, f"{time.time()}") ckpt_dir = os.path.join(out_dir, "checkpoints") os.makedirs(ckpt_dir) accelerator = initialize_accelerator( out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to ) logger = initialize_logging(os.path.basename(__file__), accelerator) # Set the accelerate seed. if config.seed is not None: set_seed(config.seed) # Log the accelerator configuration from every process to help with debugging. logger.info(accelerator.state, main_process_only=False) logger.info("Starting Training.") logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}") logger.info(f"Output dir: '{out_dir}'") # Write the configuration to disk. with open(os.path.join(out_dir, "config.json"), "w") as f: json.dump(config.dict(), f, indent=2, default=str) weight_dtype = get_dtype_from_str(config.weight_dtype) logger.info("Loading models.") tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl( logger=logger, model_name_or_path=config.model, hf_variant=config.hf_variant, vae_model=config.vae_model, dtype=weight_dtype, ) if config.xformers: import_xformers() # TODO(ryand): There is a known issue if xformers is enabled when training in mixed precision where xformers # will fail because Q, K, V have different dtypes. unet.enable_xformers_memory_efficient_attention() vae.enable_xformers_memory_efficient_attention() # Prepare text encoder output cache. 
# text_encoder_output_cache_dir_name = None if config.cache_text_encoder_outputs: raise NotImplementedError("Caching text encoder outputs is not yet supported.") else: text_encoder_1.to(accelerator.device, dtype=weight_dtype) text_encoder_2.to(accelerator.device, dtype=weight_dtype) # Prepare VAE output cache. vae_output_cache_dir_name = None if config.cache_vae_outputs: raise NotImplementedError("Caching VAE outputs is not yet supported.") else: vae.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) # Add LoRA layers to the models being trained. trainable_param_groups = [] all_trainable_models: set[torch.nn.Module] = set() def inject_lora_layers(model, lora_config: peft.LoraConfig, lr: float) -> peft.PeftModel: peft_model = peft.get_peft_model(model, lora_config) peft_model.print_trainable_parameters() # Populate `trainable_param_groups`, to be passed to the optimizer. param_group = {"params": list(filter(lambda p: p.requires_grad, peft_model.parameters())), "lr": lr} trainable_param_groups.append(param_group) # Populate all_trainable_models. all_trainable_models.add(peft_model) peft_model.train() return peft_model if config.train_unet: unet_lora_config = peft.LoraConfig( r=config.lora_rank_dim, # TODO(ryand): Diffusers uses lora_alpha=config.lora_rank_dim. Is that preferred? lora_alpha=1.0, target_modules=UNET_TARGET_MODULES, ) unet = inject_lora_layers(unet, unet_lora_config, lr=config.unet_learning_rate) if config.train_text_encoder: text_encoder_lora_config = peft.LoraConfig( r=config.lora_rank_dim, lora_alpha=1.0, # init_lora_weights="gaussian", target_modules=TEXT_ENCODER_TARGET_MODULES, ) text_encoder_1 = inject_lora_layers( text_encoder_1, text_encoder_lora_config, lr=config.text_encoder_learning_rate ) text_encoder_2 = inject_lora_layers( text_encoder_2, text_encoder_lora_config, lr=config.text_encoder_learning_rate ) if config.train_ti: # TODO(ryand): Move this private function to a shared location. 
placeholder_tokens, placeholder_token_ids_1, placeholder_token_ids_2 = _initialize_placeholder_tokens( config=config, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, logger=logger, ) logger.info(f"Initialized {len(placeholder_tokens)} placeholder tokens: {placeholder_tokens}.") # Unfreeze the token embeddings in the text encoders. text_encoder_1.text_model.embeddings.token_embedding.requires_grad_(True) text_encoder_2.text_model.embeddings.token_embedding.requires_grad_(True) all_trainable_models.add(text_encoder_1) all_trainable_models.add(text_encoder_2) for te in [text_encoder_1, text_encoder_2]: param_group = { "params": te.get_input_embeddings().parameters(), "lr": config.textual_inversion_learning_rate, } trainable_param_groups.append(param_group) # If mixed_precision is enabled, cast all trainable params to float32. if config.mixed_precision != "no": for trainable_model in all_trainable_models: for param in trainable_model.parameters(): if param.requires_grad: param.data = param.to(torch.float32) if config.gradient_checkpointing: # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained. unet.enable_gradient_checkpointing() # unet must be in train() mode for gradient checkpointing to take effect. # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does # not change its forward behavior. unet.train() if config.train_text_encoder: for te in [text_encoder_1, text_encoder_2]: te.gradient_checkpointing_enable() # The text encoders must be in train() mode for gradient checkpointing to take effect. This should # already be the case, since we are training the text_encoders, be we do it explicitly to make it clear # that this is required. # At the time of writing, the text encoder dropout probabilities default to 0, so putting the text # encoders in train mode does not change their forward behavior. 
te.train() # Set requires_grad = True on the first parameters of the text encoders. Without this, the text encoder # LoRA weights would have 0 gradients, and so would not get trained. Note that the set of # trainable_param_groups has already been populated - this won't change what gets trained. te.text_model.embeddings.requires_grad_(True) optimizer = initialize_optimizer(config.optimizer, trainable_param_groups) data_loader = build_textual_inversion_sd_dataloader( config=config.data_loader, placeholder_token=config.placeholder_token, batch_size=config.train_batch_size, use_masks=config.use_masks, vae_output_cache_dir=vae_output_cache_dir_name, ) log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler) assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1 assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1 assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1 # A "step" represents a single weight update operation (i.e. takes into account gradient accumulation steps). # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached. num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps) num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch) # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps # by the number of processes. This scaling logic was copied from the diffusers example training code, but it appears # in many places so I don't know where it originated. 
Internally, accelerate makes one LR scheduler step per process # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82), # so the scaling here simply reverses that behaviour. lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler( config.lr_scheduler, optimizer=optimizer, num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes, num_training_steps=num_train_steps * accelerator.num_processes, ) prepared_result: tuple[ UNet2DConditionModel, peft.PeftModel | CLIPTextModel, peft.PeftModel | CLIPTextModel, torch.optim.Optimizer, torch.utils.data.DataLoader, torch.optim.lr_scheduler.LRScheduler, ] = accelerator.prepare( unet, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler, # Disable automatic device placement for text_encoder if the text encoder outputs were cached. device_placement=[ True, not config.cache_text_encoder_outputs, not config.cache_text_encoder_outputs, True, True, True, ], ) unet, text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result if accelerator.is_main_process: accelerator.init_trackers("lora_and_ti_training") # Tensorboard uses markdown formatting, so we wrap the config json in a code block. accelerator.log({"configuration": f"```json\n{json.dumps(config.dict(), indent=2, default=str)}\n```\n"}) checkpoint_tracker = CheckpointTracker( base_dir=ckpt_dir, prefix="checkpoint", max_checkpoints=config.max_checkpoints, ) # Train! total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num batches = {len(data_loader)}") logger.info(f" Instantaneous batch size per device = {config.train_batch_size}") logger.info(f" Gradient accumulation steps = {config.gradient_accumulation_steps}") logger.info(f" Parallel processes = {accelerator.num_processes}") logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Total optimization steps = {num_train_steps}") logger.info(f" Total epochs = {num_train_epochs}") global_step = 0 first_epoch = 0 completed_epochs = first_epoch progress_bar = tqdm( range(global_step, num_train_steps), # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) progress_bar.set_description("Steps") ti_train_steps = num_train_steps if config.ti_train_steps_ratio is not None: ti_train_steps = math.ceil(num_train_steps * config.ti_train_steps_ratio) logger.info(f"The TI training pivot point is set at {ti_train_steps} steps.") # Keep original embeddings as reference. with torch.no_grad(): orig_embeds_params_1 = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone() orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone() def save_checkpoint(num_completed_epochs: int, num_completed_steps: int): accelerator.wait_for_everyone() if accelerator.is_main_process: _save_sdxl_lora_and_ti_checkpoint( config=config, epoch=num_completed_epochs, step=num_completed_steps, unet=unet, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, placeholder_token_ids_1=placeholder_token_ids_1, placeholder_token_ids_2=placeholder_token_ids_2, accelerator=accelerator, logger=logger, checkpoint_tracker=checkpoint_tracker, lora_checkpoint_format=config.lora_checkpoint_format, callbacks=callbacks, ) accelerator.wait_for_everyone() def validate(num_completed_epochs: int, num_completed_steps: int): accelerator.wait_for_everyone() if accelerator.is_main_process: generate_validation_images_sdxl( epoch=num_completed_epochs, step=num_completed_steps, out_dir=out_dir, accelerator=accelerator, vae=vae, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, noise_scheduler=noise_scheduler, unet=unet, config=config, logger=logger, 
callbacks=callbacks, ) accelerator.wait_for_everyone() for epoch in range(first_epoch, num_train_epochs): # TODO(ryand): Is this necessary? text_encoder_1.train() text_encoder_2.train() train_loss = 0.0 for data_batch_idx, data_batch in enumerate(data_loader): if global_step == ti_train_steps and config.train_ti: logger.info("Reached TI training pivot point. Setting TI learning rate to 0.0.") # TODO(ryand): The TI embeddings continue to be updated slightly by the normalization step in # restore_original_embeddings(...). The updates should be very small and converge quickly, so this # should be fine. But, at some point we should tidy this up. for ti_param_group in optimizer.param_groups[-2:]: # The TI param groups should be the last two param groups. But, this is pretty brittle, so this # assertion adds a bit of safety. assert len(ti_param_group["params"]) == 1 ti_param_group["lr"] = 0.0 with accelerator.accumulate(unet, text_encoder_1, text_encoder_2): loss = train_forward( accelerator=accelerator, data_batch=data_batch, vae=vae, noise_scheduler=noise_scheduler, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, unet=unet, weight_dtype=weight_dtype, resolution=config.data_loader.resolution, use_masks=config.use_masks, prediction_type=config.prediction_type, min_snr_gamma=config.min_snr_gamma, ) # Gather the losses across all processes for logging (if we use distributed training). # TODO(ryand): Test that this works properly with distributed training. avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean() train_loss += avg_loss.item() / config.gradient_accumulation_steps # Backpropagate. 
accelerator.backward(loss) if accelerator.sync_gradients and config.max_grad_norm is not None: params_to_clip = itertools.chain.from_iterable([m.parameters() for m in all_trainable_models]) accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) # Make sure we don't update any embedding weights besides the newly-added token(s). # TODO(ryand): Should we only do this if accelerator.sync_gradients? restore_original_embeddings( tokenizer=tokenizer_1, placeholder_token_ids=placeholder_token_ids_1, accelerator=accelerator, text_encoder=text_encoder_1, orig_embeds_params=orig_embeds_params_1, ) restore_original_embeddings( tokenizer=tokenizer_2, placeholder_token_ids=placeholder_token_ids_2, accelerator=accelerator, text_encoder=text_encoder_2, orig_embeds_params=orig_embeds_params_2, ) # Checks if the accelerator has performed an optimization step behind the scenes. if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1 log = {"train_loss": train_loss} lrs = lr_scheduler.get_last_lr() # Prepare LR names in the same order that their respective param groups were added to the optimizer. # TODO: Do this at the time that we prepare the param groups? lr_names = [] if config.train_unet: lr_names.append("unet") if config.train_text_encoder: lr_names.append("text_encoder_1") lr_names.append("text_encoder_2") if config.train_ti: lr_names.append("ti_embeddings_1") lr_names.append("ti_embeddings_2") for lr_idx, lr_name in enumerate(lr_names): log[f"lr/{lr_name}"] = float(lrs[lr_idx]) if config.optimizer.optimizer_type == "Prodigy": log[f"lr/d*lr/{lr_name}"] = ( optimizer.param_groups[lr_idx]["d"] * optimizer.param_groups[lr_idx]["lr"] ) accelerator.log(log, step=global_step) train_loss = 0.0 # global_step represents the *number of completed steps* at this point. 
if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0: save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step) if ( config.validate_every_n_steps is not None and global_step % config.validate_every_n_steps == 0 and len(config.validation_prompts) > 0 ): validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step) logs = { "step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], } progress_bar.set_postfix(**logs) if global_step >= num_train_steps: break # Save a checkpoint every n epochs. if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0: save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step) # Generate validation images every n epochs. if ( config.validate_every_n_epochs is not None and completed_epochs % config.validate_every_n_epochs == 0 and len(config.validation_prompts) > 0 ): validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step) accelerator.end_training() ================================================ FILE: src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/__init__.py ================================================ ================================================ FILE: src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/config.py ================================================ from typing import Literal from pydantic import model_validator from invoke_training.config.base_pipeline_config import BasePipelineConfig from invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig class SdxlTextualInversionConfig(BasePipelineConfig): type: Literal["SDXL_TEXTUAL_INVERSION"] = "SDXL_TEXTUAL_INVERSION" """Must be `SDXL_TEXTUAL_INVERSION`. This is what differentiates training pipeline types. 
""" model: str = "stabilityai/stable-diffusion-xl-base-1.0" """Name or path of the base model to train. Can be in diffusers format, or a single stable diffusion checkpoint file. (E.g. 'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/JuggernautXL.safetensors', etc. ) """ hf_variant: str | None = "fp16" """The Hugging Face Hub model variant to use. Only applies if `model` is a Hugging Face Hub model name. """ # Helpful discussion for understanding how this works at inference time: # https://github.com/huggingface/diffusers/pull/3144#discussion_r1172413509 num_vectors: int = 1 """Note: `num_vectors` can be overridden by `initial_phrase`. The number of textual inversion embedding vectors that will be used to learn the concept. Increasing the `num_vectors` enables the model to learn more complex concepts, but has the following drawbacks: - greater risk of overfitting - increased size of the resulting output file - consumes more of the prompt capacity at inference time Typical values for `num_vectors` are in the range [1, 16]. As a rule of thumb, `num_vectors` can be increased as the size of the dataset increases (without overfitting). """ placeholder_token: str """The special word to associate the learned embeddings with. Choose a unique token that is unlikely to already exist in the tokenizer's vocabulary. """ initializer_token: str | None = None """Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set. A vocabulary token to use as an initializer for the placeholder token. It should be a single word that roughly describes the object or style that you're trying to train on. Must map to a single tokenizer token. For example, if you are training on a dataset of images of your pet dog, a good choice would be `dog`. """ initial_embedding_file: str | None = None """Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set. 
Path to an existing TI embedding that will be used to initialize the embedding being trained. The placeholder token in the file must match the `placeholder_token` field. Either `initializer_token` or `initial_embedding_file` should be set. """ initial_phrase: str | None = None """Note: Exactly one of `initializer_token`, `initial_embedding_file`, or `initial_phrase` should be set. A phrase that will be used to initialize the placeholder token embedding. The phrase will be tokenized, and the corresponding embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be inferred from the length of the tokenized phrase, so keep the phrase short. The consequences of training a large number of embedding vectors are discussed in the `num_vectors` field documentation. For example, if you are training on a dataset of images of pokemon, you might use `pokemon sketch white background`. """ optimizer: AdamOptimizerConfig | ProdigyOptimizerConfig = AdamOptimizerConfig() lr_scheduler: Literal[ "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" ] = "constant" lr_warmup_steps: int = 0 """The number of warmup steps in the learning rate scheduler. Only applied to schedulers that support warmup. See lr_scheduler. """ min_snr_gamma: float | None = 5.0 """Min-SNR weighting for diffusion training was introduced in https://arxiv.org/abs/2303.09556. This strategy improves the speed of training convergence by adjusting the weight of each sample. `min_snr_gamma` acts like an an upper bound on the weight of samples with low noise levels. If `None`, then Min-SNR weighting will not be applied. If enabled, the recommended value is `min_snr_gamma = 5.0`. """ cache_vae_outputs: bool = False """If True, the VAE will be applied to all of the images in the dataset before starting training and the results will be cached to disk. 
This reduces the VRAM requirements during training (don't have to keep the VAE in VRAM), and speeds up training (don't have to run the VAE encoding step). This option can only be enabled if all non-deterministic image augmentations are disabled (i.e. `center_crop=True`, `random_flip=False`, etc.). """ enable_cpu_offload_during_validation: bool = False """If True, models will be kept in CPU memory and loaded into GPU memory one-by-one while generating validation images. This reduces VRAM requirements at the cost of slower generation of validation images. """ gradient_accumulation_steps: int = 1 """The number of gradient steps to accumulate before each weight update. This is an alternative to increasing the `train_batch_size` when training with limited VRAM. """ weight_dtype: Literal["float32", "float16", "bfloat16"] = "bfloat16" """All weights (trainable and fixed) will be cast to this precision. Lower precision dtypes require less VRAM, and result in faster training, but are more prone to issues with numerical stability. Recommendations: - `"float32"`: Use this mode if you have plenty of VRAM available. - `"bfloat16"`: Use this mode if you have limited VRAM and a GPU that supports bfloat16. - `"float16"`: Use this mode if you have limited VRAM and a GPU that does not support bfloat16. See also [`mixed_precision`][invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config.SdxlTextualInversionConfig.mixed_precision]. """ # noqa: E501 mixed_precision: Literal["no", "fp16", "bf16", "fp8"] = "no" """The mixed precision mode to use. If mixed precision is enabled, then all non-trainable parameters will be cast to the specified `weight_dtype`, and trainable parameters are kept in float32 precision to avoid issues with numerical stability. This value is passed to Hugging Face Accelerate. See [`accelerate.Accelerator.mixed_precision`](https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.mixed_precision) for more details. 
""" # noqa: E501 xformers: bool = False """If `True`, use xformers for more efficient attention blocks. """ gradient_checkpointing: bool = False """Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass. Enabling gradient checkpointing slows down training by ~20%. """ max_checkpoints: int | None = None """The maximum number of checkpoints to keep. New checkpoints will replace earlier checkpoints to stay under this limit. Note that this limit is applied to 'step' and 'epoch' checkpoints separately. """ prediction_type: Literal["epsilon", "v_prediction"] | None = None """The prediction type that will be used for training. If `None`, the prediction type will be inferred from the scheduler. """ max_grad_norm: float | None = None """Maximum gradient norm for gradient clipping. Set to `None` for no clipping. """ validation_prompts: list[str] = [] """A list of prompts that will be used to generate images throughout training for the purpose of tracking progress. """ negative_validation_prompts: list[str] | None = None """A list of negative prompts that will be applied when generating validation images. If set, this list should have the same length as 'validation_prompts'. """ num_validation_images_per_prompt: int = 4 """The number of validation images to generate for each prompt in `validation_prompts`. Careful, validation can become very slow if this number is too large. """ train_batch_size: int = 4 """The training batch size. """ use_masks: bool = False """If True, image masks will be applied to weight the loss during training. The dataset must contain masks for this feature to be used. """ data_loader: TextualInversionSDDataLoaderConfig """The data configuration. See [`TextualInversionSDDataLoaderConfig`][invoke_training.config.data.data_loader_config.TextualInversionSDDataLoaderConfig] for details. """ vae_model: str | None = None """The name of the Hugging Face Hub VAE model to train against. 
If set, this will override the VAE bundled with the base model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL 1.0 shipped with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version. """ @model_validator(mode="after") def check_validation_prompts(self): if self.negative_validation_prompts is not None and len(self.negative_validation_prompts) != len( self.validation_prompts ): raise ValueError( f"The number of validation_prompts ({len(self.validation_prompts)}) must match the number of " f"negative_validation_prompts ({len(self.negative_validation_prompts)})." ) return self ================================================ FILE: src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/train.py ================================================ import json import logging import math import os import tempfile import time import torch import torch.utils.data from accelerate import Accelerator from accelerate.utils import set_seed from diffusers.optimization import get_scheduler from tqdm.auto import tqdm from transformers import CLIPPreTrainedModel, CLIPTextModel, CLIPTokenizer, PreTrainedTokenizer from invoke_training._shared.accelerator.accelerator_utils import ( get_dtype_from_str, initialize_accelerator, initialize_logging, ) from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker from invoke_training._shared.checkpoints.serialization import save_state_dict from invoke_training._shared.data.data_loaders.textual_inversion_sd_dataloader import ( build_textual_inversion_sd_dataloader, ) from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import log_aspect_ratio_buckets from invoke_training._shared.optimizer.optimizer_utils import initialize_optimizer from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sdxl from invoke_training._shared.stable_diffusion.textual_inversion import ( 
initialize_placeholder_tokens_from_initial_phrase, initialize_placeholder_tokens_from_initializer_token, restore_original_embeddings, ) from invoke_training._shared.stable_diffusion.validation import generate_validation_images_sdxl from invoke_training._shared.utils.import_xformers import import_xformers from invoke_training.pipelines.callbacks import ModelCheckpoint, ModelType, PipelineCallbacks, TrainingCheckpoint from invoke_training.pipelines.stable_diffusion_xl.lora.train import cache_vae_outputs, train_forward from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import ( SdxlLoraAndTextualInversionConfig, ) from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig def _save_ti_embeddings( epoch: int, step: int, text_encoder_1: CLIPTextModel, text_encoder_2: CLIPTextModel, placeholder_token_ids_1: list[int], placeholder_token_ids_2: list[int], accelerator: Accelerator, logger: logging.Logger, checkpoint_tracker: CheckpointTracker, callbacks: list[PipelineCallbacks] | None, ): """Save a Textual Inversion SDXL checkpoint. Old checkpoints are deleted if necessary to respect the checkpoint_tracker limits. """ # Prune checkpoints and get new checkpoint path. 
num_pruned = checkpoint_tracker.prune(1) if num_pruned > 0: logger.info(f"Pruned {num_pruned} checkpoint(s).") save_path = checkpoint_tracker.get_path(epoch=epoch, step=step) learned_embeds_1 = ( accelerator.unwrap_model(text_encoder_1) .get_input_embeddings() .weight[min(placeholder_token_ids_1) : max(placeholder_token_ids_1) + 1] ) learned_embeds_2 = ( accelerator.unwrap_model(text_encoder_2) .get_input_embeddings() .weight[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] ) learned_embeds_dict = { "clip_l": learned_embeds_1.detach().cpu().to(dtype=torch.float32), "clip_g": learned_embeds_2.detach().cpu().to(dtype=torch.float32), } save_state_dict(learned_embeds_dict, save_path) if callbacks is not None: for cb in callbacks: cb.on_save_checkpoint( TrainingCheckpoint( models=[ModelCheckpoint(file_path=save_path, model_type=ModelType.SDXL_TEXTUAL_INVERSION)], epoch=epoch, step=step, ) ) def _initialize_placeholder_tokens( config: SdxlTextualInversionConfig | SdxlLoraAndTextualInversionConfig, tokenizer_1: CLIPTokenizer, tokenizer_2: CLIPTokenizer, text_encoder_1: PreTrainedTokenizer, text_encoder_2: PreTrainedTokenizer, logger: logging.Logger, ) -> tuple[list[str], list[int], list[int]]: """Prepare the tokenizers and text_encoders for TI training. - Add the placeholder tokens to the tokenizers. - Add new token embeddings to the text_encoders for each of the placeholder tokens. - Initialize the new token embeddings from either an existing token, or an initial TI embedding file. """ if ( sum( [ getattr(config, "initializer_token", None) is not None, getattr(config, "initial_embedding_file", None) is not None, getattr(config, "initial_phrase", None) is not None, ] ) != 1 ): raise ValueError( "Exactly one of 'initializer_token', 'initial_embedding_file', or 'initial_phrase' should be set." 
) if getattr(config, "initializer_token", None) is not None: placeholder_tokens_1, placeholder_token_ids_1 = initialize_placeholder_tokens_from_initializer_token( tokenizer=tokenizer_1, text_encoder=text_encoder_1, initializer_token=config.initializer_token, placeholder_token=config.placeholder_token, num_vectors=config.num_vectors, logger=logger, ) placeholder_tokens_2, placeholder_token_ids_2 = initialize_placeholder_tokens_from_initializer_token( tokenizer=tokenizer_2, text_encoder=text_encoder_2, initializer_token=config.initializer_token, placeholder_token=config.placeholder_token, num_vectors=config.num_vectors, logger=logger, ) elif getattr(config, "initial_embedding_file", None) is not None: # TODO(ryan) raise NotImplementedError("Initializing from an initial embedding is not yet supported for SDXL.") elif getattr(config, "initial_phrase", None) is not None: placeholder_tokens_1, placeholder_token_ids_1 = initialize_placeholder_tokens_from_initial_phrase( tokenizer=tokenizer_1, text_encoder=text_encoder_1, initial_phrase=config.initial_phrase, placeholder_token=config.placeholder_token, ) placeholder_tokens_2, placeholder_token_ids_2 = initialize_placeholder_tokens_from_initial_phrase( tokenizer=tokenizer_2, text_encoder=text_encoder_2, initial_phrase=config.initial_phrase, placeholder_token=config.placeholder_token, ) else: raise ValueError( "Exactly one of 'initializer_token', 'initial_embedding_file', or 'initial_phrase' should be set." ) assert placeholder_tokens_1 == placeholder_tokens_2 return placeholder_tokens_1, placeholder_token_ids_1, placeholder_token_ids_2 def train(config: SdxlTextualInversionConfig, callbacks: list[PipelineCallbacks] | None = None): # noqa: C901 # Create a timestamped directory for all outputs. 
out_dir = os.path.join(config.base_output_dir, f"{time.time()}") ckpt_dir = os.path.join(out_dir, "checkpoints") os.makedirs(ckpt_dir) accelerator = initialize_accelerator( out_dir, config.gradient_accumulation_steps, config.mixed_precision, config.report_to ) logger = initialize_logging(os.path.basename(__file__), accelerator) # Set the accelerate seed. if config.seed is not None: set_seed(config.seed) # Log the accelerator configuration from every process to help with debugging. logger.info(accelerator.state, main_process_only=False) logger.info("Starting Training.") logger.info(f"Configuration:\n{json.dumps(config.dict(), indent=2, default=str)}") logger.info(f"Output dir: '{out_dir}'") # Write the configuration to disk. with open(os.path.join(out_dir, "config.json"), "w") as f: json.dump(config.dict(), f, indent=2, default=str) weight_dtype = get_dtype_from_str(config.weight_dtype) logger.info("Loading models.") tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl( logger=logger, model_name_or_path=config.model, hf_variant=config.hf_variant, vae_model=config.vae_model, dtype=weight_dtype, ) placeholder_tokens, placeholder_token_ids_1, placeholder_token_ids_2 = _initialize_placeholder_tokens( config=config, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, logger=logger, ) logger.info(f"Initialized {len(placeholder_tokens)} placeholder tokens: {placeholder_tokens}.") # All parameters of the VAE, UNet, and text encoder are currently frozen. Just unfreeze the token embeddings in the # text encoders. text_encoder_1.text_model.embeddings.token_embedding.requires_grad_(True) text_encoder_2.text_model.embeddings.token_embedding.requires_grad_(True) if config.gradient_checkpointing: # We want to enable gradient checkpointing in the UNet regardless of whether it is being trained. 
unet.enable_gradient_checkpointing() # unet must be in train() mode for gradient checkpointing to take effect. # At the time of writing, the unet dropout probabilities default to 0, so putting the unet in train mode does # not change its forward behavior. unet.train() for te in [text_encoder_1, text_encoder_2]: # The text_encoder will be put in .train() mode later, so we don't need to worry about that here. # Note: There are some weird interactions gradient checkpointing and requires_grad_() when training a # text_encoder LoRA. If this code ever gets copied elsewhere, make sure to take a look at how this is # handled in other training pipelines. te.gradient_checkpointing_enable() if config.xformers: import_xformers() unet.enable_xformers_memory_efficient_attention() vae.enable_xformers_memory_efficient_attention() # Prepare VAE output cache. vae_output_cache_dir_name = None if config.cache_vae_outputs: if config.data_loader.random_flip: raise ValueError("'cache_vae_outputs' cannot be True if 'random_flip' is True.") if not config.data_loader.center_crop: raise ValueError("'cache_vae_outputs' cannot be True if 'center_crop' is False.") # We use a temporary directory for the cache. The directory will automatically be cleaned up when # tmp_vae_output_cache_dir is destroyed. tmp_vae_output_cache_dir = tempfile.TemporaryDirectory() vae_output_cache_dir_name = tmp_vae_output_cache_dir.name if accelerator.is_local_main_process: # Only the main process should to populate the cache. logger.info(f"Generating VAE output cache ('{vae_output_cache_dir_name}').") vae.to(accelerator.device, dtype=weight_dtype) data_loader = build_textual_inversion_sd_dataloader( config=config.data_loader, placeholder_token=config.placeholder_token, batch_size=config.train_batch_size, use_masks=config.use_masks, shuffle=False, ) cache_vae_outputs(vae_output_cache_dir_name, data_loader, vae) # Move the VAE back to the CPU, because it is not needed for training. 
vae.to("cpu") accelerator.wait_for_everyone() else: vae.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) text_encoder_1.to(accelerator.device, dtype=weight_dtype) text_encoder_2.to(accelerator.device, dtype=weight_dtype) # Initialize the optimizer to only optimize the token embeddings. trainable_param_groups = [ {"params": text_encoder_1.get_input_embeddings().parameters()}, {"params": text_encoder_2.get_input_embeddings().parameters()}, ] optimizer = initialize_optimizer(config.optimizer, trainable_param_groups) trainable_models = torch.nn.ModuleDict({"text_encoder_1": text_encoder_1, "text_encoder_2": text_encoder_2}) data_loader = build_textual_inversion_sd_dataloader( config=config.data_loader, placeholder_token=config.placeholder_token, batch_size=config.train_batch_size, use_masks=config.use_masks, vae_output_cache_dir=vae_output_cache_dir_name, ) log_aspect_ratio_buckets(logger=logger, batch_sampler=data_loader.batch_sampler) assert sum([config.max_train_steps is not None, config.max_train_epochs is not None]) == 1 assert sum([config.save_every_n_steps is not None, config.save_every_n_epochs is not None]) == 1 assert sum([config.validate_every_n_steps is not None, config.validate_every_n_epochs is not None]) == 1 # A "step" represents a single weight update operation (i.e. takes into account gradient accumulation steps). # math.ceil(...) is used in calculating the num_steps_per_epoch, because by default an optimizer step is taken when # the end of the dataloader is reached, even if gradient_accumulation_steps hasn't been reached. num_steps_per_epoch = math.ceil(len(data_loader) / config.gradient_accumulation_steps) num_train_steps = config.max_train_steps or config.max_train_epochs * num_steps_per_epoch num_train_epochs = math.ceil(num_train_steps / num_steps_per_epoch) # TODO(ryand): Test in a distributed training environment and more clearly document the rationale for scaling steps # by the number of processes. 
This scaling logic was copied from the diffusers example training code, but it appears # in many places so I don't know where it originated. Internally, accelerate makes one LR scheduler step per process # (https://github.com/huggingface/accelerate/blame/49cb83a423f2946059117d8bb39b7c8747d29d80/src/accelerate/scheduler.py#L72-L82), # so the scaling here simply reverses that behaviour. lr_scheduler: torch.optim.lr_scheduler.LRScheduler = get_scheduler( config.lr_scheduler, optimizer=optimizer, num_warmup_steps=config.lr_warmup_steps * accelerator.num_processes, num_training_steps=num_train_steps * accelerator.num_processes, ) prepared_result: tuple[ CLIPPreTrainedModel, CLIPPreTrainedModel, torch.optim.Optimizer, torch.utils.data.DataLoader, torch.optim.lr_scheduler.LRScheduler, ] = accelerator.prepare(text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler) text_encoder_1, text_encoder_2, optimizer, data_loader, lr_scheduler = prepared_result if accelerator.is_main_process: accelerator.init_trackers("textual_inversion_training") # Tensorboard uses markdown formatting, so we wrap the config json in a code block. accelerator.log({"configuration": f"```json\n{json.dumps(config.dict(), indent=2, default=str)}\n```\n"}) checkpoint_tracker = CheckpointTracker( base_dir=ckpt_dir, prefix="checkpoint", extension=".safetensors", max_checkpoints=config.max_checkpoints, ) # Train! total_batch_size = config.train_batch_size * accelerator.num_processes * config.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num batches = {len(data_loader)}") logger.info(f" Instantaneous batch size per device = {config.train_batch_size}") logger.info(f" Gradient accumulation steps = {config.gradient_accumulation_steps}") logger.info(f" Parallel processes = {accelerator.num_processes}") logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Total optimization steps = {num_train_steps}") logger.info(f" Total epochs = {num_train_epochs}") global_step = 0 first_epoch = 0 completed_epochs = 0 progress_bar = tqdm( range(global_step, num_train_steps), # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) progress_bar.set_description("Steps") # Keep original embeddings as reference. with torch.no_grad(): orig_embeds_params_1 = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone() orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone() def save_checkpoint(num_completed_epochs: int, num_completed_steps: int): accelerator.wait_for_everyone() if accelerator.is_main_process: _save_ti_embeddings( epoch=num_completed_epochs, step=num_completed_steps, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, placeholder_token_ids_1=placeholder_token_ids_1, placeholder_token_ids_2=placeholder_token_ids_2, accelerator=accelerator, logger=logger, checkpoint_tracker=checkpoint_tracker, callbacks=callbacks, ) accelerator.wait_for_everyone() def validate(num_completed_epochs: int, num_completed_steps: int): accelerator.wait_for_everyone() if accelerator.is_main_process: generate_validation_images_sdxl( epoch=num_completed_epochs, step=num_completed_steps, out_dir=out_dir, accelerator=accelerator, vae=vae, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, noise_scheduler=noise_scheduler, unet=unet, config=config, logger=logger, callbacks=callbacks, ) accelerator.wait_for_everyone() for epoch in range(first_epoch, num_train_epochs): text_encoder_1.train() text_encoder_2.train() train_loss = 0.0 for data_batch_idx, data_batch in enumerate(data_loader): with accelerator.accumulate(trainable_models): loss = train_forward( accelerator=accelerator, 
data_batch=data_batch, vae=vae, noise_scheduler=noise_scheduler, tokenizer_1=tokenizer_1, tokenizer_2=tokenizer_2, text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2, unet=unet, weight_dtype=weight_dtype, resolution=config.data_loader.resolution, use_masks=config.use_masks, prediction_type=config.prediction_type, min_snr_gamma=config.min_snr_gamma, ) # Gather the losses across all processes for logging (if we use distributed training). # TODO(ryand): Test that this works properly with distributed training. avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean() train_loss += avg_loss.item() / config.gradient_accumulation_steps # Backpropagate. accelerator.backward(loss) if accelerator.sync_gradients and config.max_grad_norm is not None: # TODO(ryand): I copied this from another pipeline. Should probably just clip the trainable params. params_to_clip = trainable_models.parameters() accelerator.clip_grad_norm_(params_to_clip, config.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=True) # Make sure we don't update any embedding weights besides the newly-added token(s). # TODO(ryand): Should we only do this if accelerator.sync_gradients? restore_original_embeddings( tokenizer=tokenizer_1, placeholder_token_ids=placeholder_token_ids_1, accelerator=accelerator, text_encoder=text_encoder_1, orig_embeds_params=orig_embeds_params_1, ) restore_original_embeddings( tokenizer=tokenizer_2, placeholder_token_ids=placeholder_token_ids_2, accelerator=accelerator, text_encoder=text_encoder_2, orig_embeds_params=orig_embeds_params_2, ) # Checks if the accelerator has performed an optimization step behind the scenes. 
if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 completed_epochs = epoch if (data_batch_idx + 1) < len(data_loader) else epoch + 1 log = {"train_loss": train_loss, "lr": lr_scheduler.get_last_lr()[0]} if config.optimizer.optimizer_type == "Prodigy": # TODO(ryand): Test Prodigy logging. log["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] accelerator.log(log, step=global_step) train_loss = 0.0 # global_step represents the *number of completed steps* at this point. if config.save_every_n_steps is not None and global_step % config.save_every_n_steps == 0: save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step) if ( config.validate_every_n_steps is not None and global_step % config.validate_every_n_steps == 0 and len(config.validation_prompts) > 0 ): validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step) logs = { "step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], } progress_bar.set_postfix(**logs) if global_step >= num_train_steps: break # Save a checkpoint every n epochs. if config.save_every_n_epochs is not None and completed_epochs % config.save_every_n_epochs == 0: save_checkpoint(num_completed_epochs=completed_epochs, num_completed_steps=global_step) # Generate validation images every n epochs. if ( config.validate_every_n_epochs is not None and completed_epochs % config.validate_every_n_epochs == 0 and len(config.validation_prompts) > 0 ): validate(num_completed_epochs=completed_epochs, num_completed_steps=global_step) accelerator.end_training() ================================================ FILE: src/invoke_training/sample_configs/_experimental/sd_dpo_lora_pickapic_1x24gb.yaml ================================================ # Training mode: Direct Preference Optimization LoRA Training # Dataset: A small subset of the pickapic_v2 dataset. # Base model: SD 1.5 # GPU: 1 x 24GB # # Training takes ~2 hours on a single RTX 4090. 
type: SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA seed: 1 base_output_dir: output/dpo optimizer: optimizer_type: AdamW learning_rate: 1e-4 weight_decay: 1e-2 lr_warmup_steps: 200 lr_scheduler: cosine data_loader: type: IMAGE_PAIR_PREFERENCE_SD_DATA_LOADER dataset: type: HF_HUB_IMAGE_PAIR_PREFERENCE_DATASET resolution: 512 # General model: runwayml/stable-diffusion-v1-5 gradient_accumulation_steps: 2 weight_dtype: float16 mixed_precision: fp16 gradient_checkpointing: True max_train_steps: 5000 save_every_n_epochs: 1 save_every_n_steps: null max_checkpoints: 100 validation_prompts: - A monk in an orange robe by a round window in a spaceship in dramatic lighting - A galaxy-colored figurine is floating over the sea at sunset, photorealistic - Concept art of a mythical sky alligator with wings, nature documentary validate_every_n_epochs: 1 train_batch_size: 4 num_validation_images_per_prompt: 1 ================================================ FILE: src/invoke_training/sample_configs/_experimental/sd_dpo_lora_refinement_pokemon_1x24gb.yaml ================================================ # Training mode: Direct Preference Optimization LoRA Training # Base model: SD 1.5 # GPU: 1 x 24GB type: SD_DIRECT_PREFERENCE_OPTIMIZATION_LORA seed: 1 base_output_dir: output/dpo optimizer: optimizer_type: AdamW learning_rate: 1e-4 weight_decay: 1e-2 lr_warmup_steps: 500 lr_scheduler: cosine data_loader: type: IMAGE_PAIR_PREFERENCE_SD_DATA_LOADER dataset: type: IMAGE_PAIR_PREFERENCE_DATASET dataset_dir: output/pokemon_pairs resolution: 512 dataloader_num_workers: 4 # General model: runwayml/stable-diffusion-v1-5 initial_lora: output/sd_lora_pokemon/1704824279.2765746/checkpoint_epoch-00000003 gradient_accumulation_steps: 2 weight_dtype: float16 mixed_precision: fp16 gradient_checkpointing: True max_train_steps: 5000 save_every_n_epochs: 10 save_every_n_steps: null max_checkpoints: 100 validation_prompts: - A cute yoda pokemon creature. - A cute astronaut pokemon creature. 
validate_every_n_epochs: 10 train_batch_size: 4 num_validation_images_per_prompt: 2 ================================================ FILE: src/invoke_training/sample_configs/flux_lora_1x40gb.yaml ================================================ # Training mode: LoRA # Base model: Flux.1-dev # Dataset: Bruce the Gnome # GPU: 1 x 40GB type: FLUX_LORA seed: 1 base_output_dir: output/experiments/bruce_the_gnome/flux_lora optimizer: optimizer_type: AdamW learning_rate: 1e-4 lr_warmup_steps: 1 lr_scheduler: constant transformer_learning_rate: 4e-4 text_encoder_learning_rate: 4e-4 train_text_encoder: False data_loader: type: IMAGE_CAPTION_FLUX_DATA_LOADER dataset: type: IMAGE_CAPTION_JSONL_DATASET # Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset. jsonl_path: sample_data/bruce_the_gnome/data.jsonl resolution: 768 aspect_ratio_buckets: target_resolution: 768 start_dim: 384 end_dim: 1536 divisible_by: 128 caption_prefix: "bruce the gnome" dataloader_num_workers: 4 # General model: black-forest-labs/FLUX.1-dev gradient_accumulation_steps: 1 weight_dtype: bfloat16 gradient_checkpointing: True max_train_steps: 350 save_every_n_steps: 50 validate_every_n_steps: 50 max_checkpoints: 10 validation_prompts: - A stuffed gnome at the beach with a pina colada in its hand. - A stuffed gnome reading a book in a cozy library. - A stuffed gnome sitting in a garden surrounded by colorful flowers and butterflies. train_batch_size: 4 num_validation_images_per_prompt: 3 ================================================ FILE: src/invoke_training/sample_configs/sd_lora_baroque_1x8gb.yaml ================================================ # Training mode: Finetuning with LoRA # Base model: SD 1.5 # Dataset: https://huggingface.co/datasets/InvokeAI/nga-baroque # GPU: 1 x 8GB # Instructions: # 1. Download the dataset from https://huggingface.co/datasets/InvokeAI/nga-baroque. # 2. 
Update the `jsonl_path` field in the `data_loader` section to point to the `metadata.jsonl` file of the downloaded # dataset. # Notes: # This config file has been optimized for the primary goal of achieving reasonable results *quickly* for demo purposes. type: SD_LORA seed: 1 base_output_dir: output/baroque/sd_lora optimizer: optimizer_type: Prodigy learning_rate: 1.0 weight_decay: 0.01 use_bias_correction: True safeguard_warmup: True data_loader: type: IMAGE_CAPTION_SD_DATA_LOADER dataset: type: IMAGE_CAPTION_JSONL_DATASET # Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset. jsonl_path: data/nga-baroque/metadata.jsonl resolution: 512 aspect_ratio_buckets: target_resolution: 512 start_dim: 256 end_dim: 768 divisible_by: 64 caption_prefix: "A baroque painting of" dataloader_num_workers: 4 # General model: runwayml/stable-diffusion-v1-5 gradient_accumulation_steps: 1 weight_dtype: bfloat16 gradient_checkpointing: True max_train_epochs: 15 save_every_n_epochs: 1 validate_every_n_epochs: 1 max_checkpoints: 5 validation_prompts: - A baroque painting of a woman carrying a basket of fruit. - A baroque painting of a cute Yoda creature. 
train_batch_size: 4 num_validation_images_per_prompt: 3 ================================================ FILE: src/invoke_training/sample_configs/sd_textual_inversion_gnome_1x8gb.yaml ================================================ # Training mode: Textual Inversion # Base model: SD 1.5 # GPU: 1 x 8GB type: SD_TEXTUAL_INVERSION seed: 1 base_output_dir: output/sd_ti_bruce_the_gnome optimizer: optimizer_type: AdamW learning_rate: 4e-3 lr_warmup_steps: 200 lr_scheduler: cosine data_loader: type: TEXTUAL_INVERSION_SD_DATA_LOADER dataset: type: IMAGE_DIR_DATASET dataset_dir: "sample_data/bruce_the_gnome" keep_in_memory: True caption_preset: object resolution: 512 center_crop: True random_flip: False shuffle_caption_delimiter: null aspect_ratio_buckets: target_resolution: 512 start_dim: 256 end_dim: 768 divisible_by: 64 dataloader_num_workers: 4 # General model: runwayml/stable-diffusion-v1-5 num_vectors: 4 placeholder_token: "bruce_the_gnome" initializer_token: "gnome" cache_vae_outputs: False gradient_accumulation_steps: 1 weight_dtype: bfloat16 gradient_checkpointing: True max_train_steps: 2000 save_every_n_steps: 200 validate_every_n_steps: 200 max_checkpoints: 20 validation_prompts: - A photo of bruce_the_gnome at the beach - A photo of bruce_the_gnome reading a book train_batch_size: 1 num_validation_images_per_prompt: 3 ================================================ FILE: src/invoke_training/sample_configs/sdxl_finetune_baroque_1x24gb.yaml ================================================ # Training mode: Full Finetuning # Base model: SDXL # Dataset: https://huggingface.co/datasets/InvokeAI/nga-baroque # GPU: 1 x 24GB # Instructions: # 1. Download the dataset from https://huggingface.co/datasets/InvokeAI/nga-baroque. # 2.
type: SDXL_FINETUNE seed: 1 base_output_dir: output/baroque/sdxl_finetune optimizer: optimizer_type: AdamW learning_rate: 5e-5 weight_decay: 1e-3 use_8bit: True lr_scheduler: constant_with_warmup lr_warmup_steps: 500 data_loader: type: IMAGE_CAPTION_SD_DATA_LOADER dataset: type: IMAGE_CAPTION_JSONL_DATASET # Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset. jsonl_path: data/nga-baroque/metadata.jsonl resolution: 1024 aspect_ratio_buckets: target_resolution: 1024 start_dim: 512 end_dim: 1536 divisible_by: 128 caption_prefix: "A baroque style painting," # General model: stabilityai/stable-diffusion-xl-base-1.0 save_checkpoint_format: trained_only_diffusers # vae_model: madebyollin/sdxl-vae-fp16-fix save_dtype: float16 gradient_accumulation_steps: 1 weight_dtype: bfloat16 gradient_checkpointing: True cache_vae_outputs: True cache_text_encoder_outputs: True max_train_epochs: 50 save_every_n_epochs: 3 validate_every_n_epochs: 3 # We save a max of 1 checkpoint for demo purposes, because the checkpoints take up a lot of disk space. max_checkpoints: 1 validation_prompts: - A baroque style painting of a woman carrying a basket of fruit. - A baroque style painting of a cute Yoda creature. train_batch_size: 4 num_validation_images_per_prompt: 3 ================================================ FILE: src/invoke_training/sample_configs/sdxl_finetune_robocats_1x24gb.yaml ================================================ # Training mode: Full finetune # Base model: SDXL # Dataset: Robocats # GPU: 1 x 24GB type: SDXL_FINETUNE seed: 1 base_output_dir: output/robocats/sdxl_finetune optimizer: optimizer_type: AdamW learning_rate: 2e-5 use_8bit: True lr_scheduler: constant_with_warmup lr_warmup_steps: 200 data_loader: type: IMAGE_CAPTION_SD_DATA_LOADER dataset: type: IMAGE_CAPTION_JSONL_DATASET # Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset. 
jsonl_path: /home/ryan/data/robocats/data.jsonl resolution: 1024 aspect_ratio_buckets: target_resolution: 1024 start_dim: 512 end_dim: 1536 divisible_by: 128 caption_prefix: "In the robocat style," # General model: stabilityai/stable-diffusion-xl-base-1.0 save_checkpoint_format: trained_only_diffusers # vae_model: madebyollin/sdxl-vae-fp16-fix save_dtype: float16 gradient_accumulation_steps: 1 weight_dtype: bfloat16 gradient_checkpointing: True cache_vae_outputs: True cache_text_encoder_outputs: True max_train_steps: 2000 validate_every_n_steps: 200 save_every_n_steps: 2000 # We save a max of 1 checkpoint for demo purposes, because the checkpoints take up a lot of disk space. max_checkpoints: 1 validation_prompts: - In the robocat style, a robotic lion in the jungle. - In the robocat style, a hamburger and fries. train_batch_size: 4 num_validation_images_per_prompt: 3 ================================================ FILE: src/invoke_training/sample_configs/sdxl_lora_and_ti_gnome_1x24gb.yaml ================================================ # Training mode: Finetuning with LoRA and Textual Inversion # Base model: SDXL 1.0 # GPU: 1 x 24GB type: SDXL_LORA_AND_TEXTUAL_INVERSION seed: 1 base_output_dir: output/sdxl_lora_and_ti_bruce_the_gnome optimizer: optimizer_type: AdamW learning_rate: 2e-3 lr_warmup_steps: 200 lr_scheduler: constant data_loader: type: TEXTUAL_INVERSION_SD_DATA_LOADER dataset: type: IMAGE_DIR_DATASET dataset_dir: "sample_data/bruce_the_gnome" keep_in_memory: True caption_preset: object resolution: 1024 center_crop: True random_flip: False shuffle_caption_delimiter: null dataloader_num_workers: 4 # General model: stabilityai/stable-diffusion-xl-base-1.0 vae_model: madebyollin/sdxl-vae-fp16-fix num_vectors: 2 placeholder_token: "bruce_the_gnome" initializer_token: "gnome" cache_vae_outputs: False gradient_accumulation_steps: 1 weight_dtype: bfloat16 gradient_checkpointing: True max_train_steps: 2000 save_every_n_steps: 200 validate_every_n_steps: 200 
max_checkpoints: 50 validation_prompts: - A photo of bruce_the_gnome at the beach - A photo of bruce_the_gnome reading a book train_batch_size: 1 num_validation_images_per_prompt: 3 ================================================ FILE: src/invoke_training/sample_configs/sdxl_lora_baroque_1x24gb.yaml ================================================ # Training mode: Finetuning with LoRA # Base model: SDXL 1.0 # Dataset: https://huggingface.co/datasets/InvokeAI/nga-baroque # GPU: 1 x 24GB # Instructions: # 1. Download the dataset from https://huggingface.co/datasets/InvokeAI/nga-baroque. # 2. Update the `jsonl_path` field in the `data_loader` section to point to the `metadata.jsonl` file of the downloaded # dataset. # Notes: # This config file has been optimized for the primary goal of achieving reasonable results *quickly* for demo # purposes. type: SDXL_LORA seed: 1 base_output_dir: output/baroque/sdxl_lora optimizer: optimizer_type: AdamW learning_rate: 1e-3 data_loader: type: IMAGE_CAPTION_SD_DATA_LOADER dataset: type: IMAGE_CAPTION_JSONL_DATASET # Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset. jsonl_path: data/nga-baroque/metadata_masks.jsonl resolution: 1024 aspect_ratio_buckets: target_resolution: 1024 start_dim: 512 end_dim: 1536 divisible_by: 128 caption_prefix: "A baroque painting of" # General model: stabilityai/stable-diffusion-xl-base-1.0 # vae_model: madebyollin/sdxl-vae-fp16-fix gradient_accumulation_steps: 1 weight_dtype: bfloat16 gradient_checkpointing: True cache_vae_outputs: True max_train_epochs: 16 save_every_n_epochs: 2 validate_every_n_epochs: 2 use_masks: False max_checkpoints: 5 validation_prompts: - A baroque painting of a woman carrying a basket of fruit. - A baroque painting of a cute Yoda creature. 
train_batch_size: 4 num_validation_images_per_prompt: 3 ================================================ FILE: src/invoke_training/sample_configs/sdxl_lora_baroque_1x8gb.yaml ================================================ # Training mode: Finetuning with LoRA # Base model: SDXL 1.0 # Dataset: https://huggingface.co/datasets/InvokeAI/nga-baroque # GPU: 1 x 8GB # Instructions: # 1. Download the dataset from https://huggingface.co/datasets/InvokeAI/nga-baroque. # 2. Update the `jsonl_path` field in the `data_loader` section to point to the `metadata.jsonl` file of the downloaded # dataset. # Notes: # This config file has been optimized for 2 primary goals: # - Minimize VRAM usage so that an SDXL model can be trained with only 8GB of VRAM. # - Achieve reasonable results *quickly* for demo purposes. type: SDXL_LORA seed: 1 base_output_dir: output/baroque/sdxl_lora optimizer: optimizer_type: Prodigy learning_rate: 1.0 weight_decay: 0.01 use_bias_correction: True safeguard_warmup: True data_loader: type: IMAGE_CAPTION_SD_DATA_LOADER dataset: type: IMAGE_CAPTION_JSONL_DATASET # Update the jsonl_path field to point to the metadata.jsonl file of the downloaded dataset. jsonl_path: data/nga-baroque/metadata.jsonl # TODO: More optimizations are needed to train at full 1024x1024 resolution with 8GB VRAM. resolution: 512 # aspect_ratio_buckets: # target_resolution: 1024 # start_dim: 512 # end_dim: 1536 # divisible_by: 128 caption_prefix: "A baroque painting of" # General model: stabilityai/stable-diffusion-xl-base-1.0 vae_model: madebyollin/sdxl-vae-fp16-fix train_text_encoder: False cache_text_encoder_outputs: True enable_cpu_offload_during_validation: True gradient_accumulation_steps: 4 weight_dtype: bfloat16 gradient_checkpointing: True max_train_epochs: 6 save_every_n_epochs: 1 validate_every_n_epochs: 1 max_checkpoints: 5 validation_prompts: - A baroque painting of a woman carrying a basket of fruit. - A baroque painting of a cute Yoda creature. 
train_batch_size: 1 num_validation_images_per_prompt: 3 ================================================ FILE: src/invoke_training/sample_configs/sdxl_lora_masks_gnome_1x24gb.yaml ================================================ # Training mode: LoRA with masks # Base model: SDXL 1.0 # Dataset: Bruce the Gnome # GPU: 1 x 24GB type: SDXL_LORA seed: 1 base_output_dir: output/bruce/sdxl_lora_masks optimizer: optimizer_type: AdamW learning_rate: 7e-5 lr_scheduler: constant_with_warmup lr_warmup_steps: 50 data_loader: type: IMAGE_CAPTION_SD_DATA_LOADER dataset: type: IMAGE_CAPTION_JSONL_DATASET jsonl_path: sample_data/bruce_the_gnome/data_masks.jsonl resolution: 1024 aspect_ratio_buckets: target_resolution: 1024 start_dim: 512 end_dim: 1536 divisible_by: 128 # General model: stabilityai/stable-diffusion-xl-base-1.0 # vae_model: madebyollin/sdxl-vae-fp16-fix gradient_accumulation_steps: 1 weight_dtype: bfloat16 gradient_checkpointing: True cache_vae_outputs: True max_train_steps: 500 save_every_n_steps: 50 validate_every_n_steps: 50 use_masks: True max_checkpoints: 5 validation_prompts: - A stuffed gnome at the beach with a pina colada in its hand. - A stuffed gnome reading a book in a cozy library. 
train_batch_size: 4 num_validation_images_per_prompt: 3 ================================================ FILE: src/invoke_training/sample_configs/sdxl_textual_inversion_gnome_1x24gb.yaml ================================================ # Training mode: Textual Inversion # Base model: SDXL # GPU: 1 x 24GB type: SDXL_TEXTUAL_INVERSION seed: 1 base_output_dir: output/bruce/sdxl_ti optimizer: optimizer_type: AdamW learning_rate: 2e-3 lr_warmup_steps: 200 lr_scheduler: cosine data_loader: type: TEXTUAL_INVERSION_SD_DATA_LOADER dataset: type: IMAGE_DIR_DATASET dataset_dir: "sample_data/bruce_the_gnome" keep_in_memory: True caption_preset: object resolution: 1024 center_crop: True random_flip: False shuffle_caption_delimiter: null dataloader_num_workers: 4 # General model: stabilityai/stable-diffusion-xl-base-1.0 vae_model: madebyollin/sdxl-vae-fp16-fix num_vectors: 4 placeholder_token: "bruce_the_gnome" initializer_token: "gnome" cache_vae_outputs: False gradient_accumulation_steps: 1 weight_dtype: bfloat16 gradient_checkpointing: True max_train_steps: 2000 save_every_n_steps: 200 validate_every_n_steps: 200 max_checkpoints: 20 validation_prompts: - A photo of bruce_the_gnome at the beach - A photo of bruce_the_gnome reading a book train_batch_size: 1 num_validation_images_per_prompt: 3 ================================================ FILE: src/invoke_training/sample_configs/sdxl_textual_inversion_masks_gnome_1x24gb.yaml ================================================ # Training mode: Textual Inversion with Masks # Base model: SDXL # GPU: 1 x 24GB type: SDXL_TEXTUAL_INVERSION seed: 1 base_output_dir: output/bruce/sdxl_ti_masks optimizer: optimizer_type: AdamW learning_rate: 5e-4 lr_scheduler: constant_with_warmup lr_warmup_steps: 50 data_loader: type: TEXTUAL_INVERSION_SD_DATA_LOADER dataset: type: IMAGE_CAPTION_JSONL_DATASET jsonl_path: sample_data/bruce_the_gnome/data_masks.jsonl keep_in_memory: True caption_preset: object resolution: 1024 center_crop: True 
random_flip: False shuffle_caption_delimiter: null # General model: stabilityai/stable-diffusion-xl-base-1.0 num_vectors: 16 placeholder_token: "bruce_the_gnome" initializer_token: "gnome" cache_vae_outputs: False gradient_accumulation_steps: 1 weight_dtype: bfloat16 gradient_checkpointing: True max_train_steps: 500 save_every_n_steps: 50 validate_every_n_steps: 50 use_masks: True max_checkpoints: 10 validation_prompts: - A photo of bruce_the_gnome at the beach with a pina colada in its hand. - A photo of bruce_the_gnome reading a book in a cozy library. train_batch_size: 4 num_validation_images_per_prompt: 3 ================================================ FILE: src/invoke_training/scripts/__init__.py ================================================ ================================================ FILE: src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py ================================================ import argparse import json from pathlib import Path import torch import torch.utils.data from PIL import Image from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer from invoke_training.scripts.utils.image_dir_dataset import ImageDirDataset, list_collate_fn def select_device_and_dtype(force_cpu: bool = False) -> tuple[torch.device, torch.dtype]: if force_cpu: return torch.device("cpu"), torch.float32 if torch.cuda.is_available(): return torch.device("cuda"), torch.float16 return torch.device("cpu"), torch.float32 def process_images(images: list[Image.Image], prompt: str, moondream, tokenizer) -> list[str]: # image_embeds = moondream.encode_image(image).to(device=device) # answer = moondream.answer_question(image_embeds, prompt, tokenizer) answers = moondream.batch_answer( images=images, prompts=[prompt] * len(images), tokenizer=tokenizer, ) return answers def main( prompt: str, use_cpu: bool, batch_size: int, output_path: str, dataset: torch.utils.data.Dataset, ): device, dtype = 
select_device_and_dtype(use_cpu) print(f"Using device: {device}") print(f"Using dtype: {dtype}") # Check that the output file does not already exist before spending time generating captions. out_path = Path(output_path) if out_path.exists(): raise FileExistsError(f"Output file already exists: {out_path}") # Load the model. model_id = "vikhyatk/moondream2" model_revision = "2024-04-02" tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision) # TODO(ryand): Warn about security implications of trust_remote_code=True. moondream_model: AutoModelForCausalLM = AutoModelForCausalLM.from_pretrained( model_id, trust_remote_code=True, revision=model_revision ).to(device=device, dtype=dtype) moondream_model.eval() data_loader = torch.utils.data.DataLoader( dataset, collate_fn=list_collate_fn, batch_size=batch_size, drop_last=False ) results = [] for image_batch in tqdm(data_loader): image_paths = image_batch["image_path"] answers = process_images(image_batch["image"], prompt, moondream_model, tokenizer) for image_path, answer in zip(image_paths, answers, strict=True): results.append({"image": image_path, "text": answer}) # Check that the output file does not exist immediately before writing to it. if out_path.exists(): raise FileExistsError(f"Output file already exists: {out_path}") with open(out_path, "w") as outfile: for entry in results: json.dump(entry, outfile) outfile.write("\n") print("Output saved to output.jsonl.") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the moondream captioning model on a directory of images.") parser.add_argument("--dir", type=str, required=True, help="Directory containing images.") parser.add_argument( "--prompt", type=str, default="Describe this image in 20 words or less.", help="(Optional) Prompt for the model.", ) parser.add_argument( "--cpu", action="store_true", default=False, help="Force use of CPU instead of GPU. 
If not set, a GPU will be used if available.", ) parser.add_argument( "--batch-size", type=int, default=4, help="Batch size for processing images. To maximize speed, set this to the largest value that fits in GPU " "memory.", ) parser.add_argument( "--output", type=str, default="output.jsonl", help="(Optional) Path to the output file. Default is 'output.jsonl'.", ) args = parser.parse_args() # Prepare the dataset. dataset = ImageDirDataset(args.dir) print(f"Found {len(dataset)} images in '{args.dir}'.") main(args.prompt, args.cpu, args.batch_size, args.output, dataset) ================================================ FILE: src/invoke_training/scripts/_experimental/masks/clipseg.py ================================================ import torch from PIL import Image from transformers import AutoProcessor, CLIPSegForImageSegmentation, CLIPSegProcessor def load_clipseg_model() -> tuple[CLIPSegProcessor, CLIPSegForImageSegmentation]: # Load the model. clipseg_processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") clipseg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined") return clipseg_processor, clipseg_model def run_clipseg( images: list[Image.Image], prompt: str, clipseg_processor, clipseg_model, clipseg_temp: float, device: torch.device, ) -> list[Image.Image]: """Run ClipSeg on a list of images. Args: clipseg_temp (float): Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother' and include more of the background. Recommended range: 0.5 to 1.0. """ orig_image_sizes = [img.size for img in images] prompts = [prompt] * len(images) # TODO(ryand): Should we run the same image with and without the prompt to normalize for any bias in the model? inputs = clipseg_processor(text=prompts, images=images, padding=True, return_tensors="pt") # Move inputs and clipseg_model to the correct device and dtype. 
inputs = {k: v.to(device=device) for k, v in inputs.items()} clipseg_model = clipseg_model.to(device=device) outputs = clipseg_model(**inputs) logits = outputs.logits if logits.ndim == 2: # The model squeezes the batch dimension if it's 1, so we need to unsqueeze it. logits = logits.unsqueeze(0) probs = torch.nn.functional.sigmoid(logits / clipseg_temp) # Normalize each mask to 0-255. Note that each mask is normalized independently. probs = 255 * probs / probs.amax(dim=(1, 2), keepdim=True) # Make mask greyscale. masks: list[Image.Image] = [] for prob, orig_size in zip(probs, orig_image_sizes, strict=True): mask = Image.fromarray(prob.cpu().numpy()).convert("L") mask = mask.resize(orig_size) masks.append(mask) return masks def select_device() -> torch.device: if torch.cuda.is_available(): return torch.device("cuda") return torch.device("cpu") ================================================ FILE: src/invoke_training/scripts/_experimental/masks/generate_masks.py ================================================ import argparse from pathlib import Path import torch import torch.utils.data from tqdm import tqdm from invoke_training.scripts._experimental.masks.clipseg import load_clipseg_model, run_clipseg, select_device from invoke_training.scripts.utils.image_dir_dataset import ImageDirDataset, list_collate_fn @torch.no_grad() def generate_masks(image_dir: str, prompt: str, clipseg_temp: float, batch_size: int): """Generate masks for a directory of images. Args: image_dir (str): The directory containing images. prompt (str): A short description of the thing you want to mask. E.g. 'a cat'. clipseg_temp (float): Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother'. and include more of the background. Recommended range: 0.5 to 1.0. batch_size (int): Batch size to use when processing images. Larger batch sizes may be faster but require more. 
""" device = select_device() clipseg_processor, clipseg_model = load_clipseg_model() # Prepare the dataloader. dataset = ImageDirDataset(image_dir) print(f"Found {len(dataset)} images in '{image_dir}'.") data_loader = torch.utils.data.DataLoader( dataset, collate_fn=list_collate_fn, batch_size=batch_size, drop_last=False ) # Process each image. for batch in tqdm(data_loader): masks = run_clipseg( images=batch["image"], prompt=prompt, clipseg_processor=clipseg_processor, clipseg_model=clipseg_model, clipseg_temp=clipseg_temp, device=device, ) for image_path, mask in zip(batch["image_path"], masks, strict=True): image_path = Path(image_path) out_path = image_path.parent / "masks" / (image_path.stem + ".png") out_path.parent.mkdir(exist_ok=True, parents=True) mask.save(out_path) print(f"Saved mask to: {out_path}") def main(): parser = argparse.ArgumentParser(description="Generate masks for a directory of images.") parser.add_argument("--dir", type=str, required=True, help="Directory containing images.") parser.add_argument( "--prompt", required=True, type=str, help="A short description of the thing you want to mask. E.g. 'a cat'.", ) parser.add_argument( "--clipseg-temp", type=float, default=1.0, help="Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother' and include " "more of the background. Recommended range: 0.5 to 1.0.", ) parser.add_argument( "--batch-size", type=int, default=4, help="Batch size to use when processing images. 
Larger batch sizes may be faster but require more memory.", ) args = parser.parse_args() generate_masks(image_dir=args.dir, prompt=args.prompt, clipseg_temp=args.clipseg_temp, batch_size=args.batch_size) if __name__ == "__main__": main() ================================================ FILE: src/invoke_training/scripts/_experimental/masks/generate_masks_for_jsonl_dataset.py ================================================ import argparse from pathlib import Path import torch import torch.utils.data from tqdm import tqdm from invoke_training._shared.data.datasets.image_caption_jsonl_dataset import ( MASK_COLUMN_DEFAULT, ImageCaptionJsonlDataset, ) from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl from invoke_training.scripts._experimental.masks.clipseg import load_clipseg_model, run_clipseg, select_device def collate_fn(examples): """A collate_fn that combines images into a list rather than stacking into a tensor.""" return { "id": [example["id"] for example in examples], "image": [example["image"] for example in examples], } def validate_out_json_path(out_json_path: str | Path): out_json_path = Path(out_json_path) if out_json_path.exists(): raise FileExistsError(f"Output jsonl file '{out_json_path}' already exists.") if not out_json_path.suffix == ".jsonl": raise ValueError(f"Output jsonl file '{out_json_path}' must have a .jsonl extension.") @torch.no_grad() def generate_masks( in_jsonl_path: str, out_jsonl_path: str, image_column: str, caption_column: str, prompt: str, clipseg_temp: float, batch_size: int, ): """Generate masks for a .jsonl dataset.""" # Load the .jsonl dataset. dataset = ImageCaptionJsonlDataset( jsonl_path=in_jsonl_path, image_column=image_column, caption_column=caption_column ) print(f"Loaded dataset from '{in_jsonl_path}' with {len(dataset)} images.") data_loader = torch.utils.data.DataLoader(dataset, collate_fn=collate_fn, batch_size=batch_size, drop_last=False) # We also need the raw jsonl data. 
jsonl_data = load_jsonl(in_jsonl_path) # Prepare output locations. out_jsonl_path = Path(out_jsonl_path) validate_out_json_path(out_jsonl_path) out_masks_dir = out_jsonl_path.parent / "masks" out_masks_dir.mkdir(exist_ok=False, parents=True) clipseg_processor, clipseg_model = load_clipseg_model() device = select_device() # Process each image. for batch in tqdm(data_loader): masks = run_clipseg( images=batch["image"], prompt=prompt, clipseg_processor=clipseg_processor, clipseg_model=clipseg_model, clipseg_temp=clipseg_temp, device=device, ) for id, mask in zip(batch["id"], masks, strict=True): orig_image_path = Path(jsonl_data[int(id)][image_column]) out_mask_path: Path = out_masks_dir / (orig_image_path.stem + ".png") mask.save(out_mask_path) print(f"Saved mask to: {out_mask_path}") # Infer whether the mask path should be relative or absolute based on the image path. if orig_image_path.is_absolute(): jsonl_data[int(id)][MASK_COLUMN_DEFAULT] = str(out_mask_path.resolve()) else: jsonl_data[int(id)][MASK_COLUMN_DEFAULT] = str(out_mask_path.relative_to(out_jsonl_path.parent)) # Save the modified jsonl data. validate_out_json_path(out_jsonl_path) save_jsonl(jsonl_data, out_jsonl_path) print(f"Saved modified jsonl data to: {out_jsonl_path}") def main(): parser = argparse.ArgumentParser(description="Generate masks for a jsonl dataset.") parser.add_argument("--in-jsonl", type=str, required=True, help="Path to the dataset .jsonl file.") parser.add_argument( "--out-jsonl", type=str, required=True, help="Path to save the modified .jsonl file to. A masks/ directory will be created in the same directory as " "the .jsonl file to store the masks. 
The choice of whether to use relative or absolute paths for the masks is " "inferred from the image paths.", ) parser.add_argument( "--image-column", type=str, default="image", help="The name of the column containing image paths in the input .jsonl file.", ) parser.add_argument( "--caption-column", type=str, default="text", help="The name of the column containing captions in the input .jsonl file.", ) parser.add_argument( "--prompt", required=True, type=str, help="A short description of the thing you want to mask. E.g. 'a cat'.", ) parser.add_argument( "--clipseg-temp", type=float, default=1.0, help="Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother' and include " "more of the background. Recommended range: 0.5 to 1.0.", ) parser.add_argument( "--batch-size", type=int, default=4, help="Batch size to use when processing images. Larger batch sizes may be faster but require more memory.", ) args = parser.parse_args() generate_masks( in_jsonl_path=args.in_jsonl, out_jsonl_path=args.out_jsonl, image_column=args.image_column, caption_column=args.caption_column, prompt=args.prompt, clipseg_temp=args.clipseg_temp, batch_size=args.batch_size, ) if __name__ == "__main__": main() ================================================ FILE: src/invoke_training/scripts/_experimental/rank_images.py ================================================ import argparse import os import time from pathlib import Path from typing import Literal import gradio as gr import yaml from pydantic import TypeAdapter from invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset from invoke_training.config.pipeline_config import PipelineConfig def parse_args(): parser = argparse.ArgumentParser(description="Choose preferences from image pairs.") parser.add_argument( "-c", "--cfg-file", type=Path, required=True, help="Path to the YAML training config file. 
The internal dataset config will be used.", ) return parser.parse_args() def clip(val, min_val, max_val): return max(min(val, max_val), min_val) def main(): args = parse_args() # Load YAML config file. with open(args.cfg_file, "r") as f: cfg = yaml.safe_load(f) pipeline_adapter: TypeAdapter[PipelineConfig] = TypeAdapter(PipelineConfig) train_config = pipeline_adapter.validate_python(cfg) dataset_config = train_config.data_loader.dataset assert dataset_config.type == "IMAGE_PAIR_PREFERENCE_DATASET" metadata = ImagePairPreferenceDataset.load_metadata(dataset_config.dataset_dir) print(f"Launching UI to rank image pairs in '{dataset_config.dataset_dir}'.") def get_img_path(index: int, image_id: Literal["image_0", "image_1"]): return os.path.join(dataset_config.dataset_dir, metadata[index][image_id]) def get_state(index: int): img_0 = get_img_path(index, "image_0") img_1 = get_img_path(index, "image_1") prefer_0 = metadata[index]["prefer_0"] prefer_1 = metadata[index]["prefer_1"] caption = metadata[index]["prompt"] return [index, img_0, img_1, prefer_0, prefer_1, caption] def go_to_index(index: int): new_index = clip(index, 0, len(metadata) - 1) return get_state(new_index) def mark_prefer_0(index: int): metadata[index]["prefer_0"] = True metadata[index]["prefer_1"] = False # Step to next example. return go_to_index(index + 1) def mark_prefer_1(index: int): metadata[index]["prefer_0"] = False metadata[index]["prefer_1"] = True # Step to next example. 
return go_to_index(index + 1) def save_metadata(): timestamp = str(time.time()).replace(".", "_") metadata_file = f"metadata-{timestamp}.jsonl" metadata_path = ImagePairPreferenceDataset.save_metadata( metadata=metadata, dataset_dir=dataset_config.dataset_dir, metadata_file=metadata_file ) print(f"Saved metadata to '{metadata_path}'.") with gr.Blocks() as demo: index = gr.Number(value=-1, precision=0) with gr.Row(): img_0 = gr.Image(type="filepath", label="Image 0", interactive=False) img_1 = gr.Image(type="filepath", label="Image 1", interactive=False) caption = gr.Textbox(interactive=False, show_label=False) with gr.Row(): prefer_0 = gr.Checkbox(label="Prefer 0", interactive=False) prefer_1 = gr.Checkbox(label="Prefer 1", interactive=False) with gr.Row(): mark_prefer_0_button = gr.Button("Prefer 0") mark_prefer_1_button = gr.Button("Prefer 1") save_metadata_button = gr.Button("Save Metadata") index.change(go_to_index, inputs=[index], outputs=[index, img_0, img_1, prefer_0, prefer_1, caption]) mark_prefer_0_button.click( mark_prefer_0, inputs=[index], outputs=[index, img_0, img_1, prefer_0, prefer_1, caption] ) mark_prefer_1_button.click( mark_prefer_1, inputs=[index], outputs=[index, img_0, img_1, prefer_0, prefer_1, caption] ) save_metadata_button.click(save_metadata) demo.launch() if __name__ == "__main__": main() ================================================ FILE: src/invoke_training/scripts/convert_sd_lora_to_kohya_format.py ================================================ import argparse from pathlib import Path import torch from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import ( convert_sd_peft_checkpoint_to_kohya_state_dict, ) def parse_args(): parser = argparse.ArgumentParser( description="Convert a Stable Diffusion LoRA checkpoint in PEFT format to kohya format." 
) parser.add_argument( "--src-ckpt-dir", type=str, required=True, help="Path to the source checkpoint directory.", ) parser.add_argument( "--dst-ckpt-file", type=str, required=True, help="Path to the destination Kohya checkpoint file.", ) parser.add_argument( "--dtype", type=str, default="fp16", help="The precision to save the kohya state dict in. One of ['fp16', 'fp32'].", ) return parser.parse_args() def main(): args = parse_args() in_checkpoint_dir = Path(args.src_ckpt_dir) out_checkpoint_file = Path(args.dst_ckpt_file) if args.dtype == "fp32": dtype = torch.float32 elif args.dtype == "fp16": dtype = torch.float16 else: raise ValueError(f"Unsupported --dtype = '{args.dtype}'.") convert_sd_peft_checkpoint_to_kohya_state_dict( in_checkpoint_dir=in_checkpoint_dir, out_checkpoint_file=out_checkpoint_file, dtype=dtype ) print(f"Saved kohya checkpoint to '{out_checkpoint_file}'.") if __name__ == "__main__": main() ================================================ FILE: src/invoke_training/scripts/invoke_generate_images.py ================================================ import argparse from pathlib import Path from invoke_training._shared.stable_diffusion.model_loading_utils import PipelineVersionEnum from invoke_training._shared.tools.generate_images import generate_images def parse_args(): parser = argparse.ArgumentParser( description="Generate a dataset of images from a single prompt. (Typically used to generate prior " "preservation/regularization datasets.)" ) parser.add_argument( "-o", "--out-dir", type=str, required=True, help="Path to the directory where the images will be stored.", ) parser.add_argument( "-m", "--model", type=str, required=True, help="Name or path of the diffusers model to generate images with. Can be in diffusers format, or a single " "stable diffusion checkpoint file. (E.g. 'runwayml/stable-diffusion-v1-5', " "'stabilityai/stable-diffusion-xl-base-1.0', '/path/to/realisticVisionV51_v51VAE.safetensors', etc. 
)", ) parser.add_argument( "-v", "--variant", type=str, required=False, default=None, help="The Hugging Face Hub model variant to use. Only applies if `--model` is a Hugging Face Hub model name.", ) parser.add_argument( "-l", "--lora", type=str, nargs="*", help="LoRA models to apply to the base model. The LoRA weight can optionally be provided after a colon " "separator. E.g. `-l path/to/lora.bin:0.5 -l path/to/lora_2.safetensors`. ", ) parser.add_argument( "--ti", type=str, nargs="*", help="Paths(s) to Textual Inversion embeddings to apply to the base model.", ) parser.add_argument( "--sd-version", type=str, required=True, help="The Stable Diffusion version. One of: ['SD', 'SDXL'].", ) # One of --prompt or --prompt-file. group = parser.add_mutually_exclusive_group(required=True) group.add_argument("-p", "--prompt", type=str, help="The prompt to use for image generation.") group.add_argument("--prompt-file", type=str, help="A file containing prompts. One per line.") parser.add_argument( "--set-size", type=int, default=1, help="The number of images generated in each 'set' for a given prompt." ) parser.add_argument("--num-sets", type=int, default=1, help="The number of 'sets' to generate for each prompt.") parser.add_argument( "--height", type=int, required=True, help="The height of the generated images in pixels.", ) parser.add_argument( "--width", type=int, required=True, help="The width of the generated images in pixels.", ) parser.add_argument( "-s", "--seed", type=int, default=0, help="Seed for repeatability.", ) parser.add_argument( "--enable-cpu-offload", default=False, action="store_true", help="If True, models will be loaded onto the GPU one by one to conserve VRAM.", ) return parser.parse_args() def parse_lora_args(lora_args: list[str] | None) -> list[tuple[Path, int]]: loras: list[tuple[Path, int]] = [] lora_args = lora_args or [] for lora in lora_args: lora_split = lora.split(":") if len(lora_split) == 1: # If weight is not specified, assume 1.0. 
loras.append((Path(lora_split[0]), 1.0)) elif len(lora_split) == 2: loras.append((Path(lora_split[0]), float(lora_split[1]))) else: raise ValueError(f"Invalid lora argument syntax: '{lora}'.") return loras def parse_prompt_file(prompt_file: str) -> list[str]: with open(prompt_file) as f: prompts = f.readlines() return [p.strip() for p in prompts] def main(): args = parse_args() loras = parse_lora_args(args.lora) if args.prompt: prompts = [args.prompt] else: prompts = parse_prompt_file(args.prompt_file) print(f"Generating {args.num_sets} sets of {args.set_size} images for {len(prompts)} prompts in '{args.out_dir}'.") generate_images( out_dir=args.out_dir, model=args.model, hf_variant=args.variant, pipeline_version=PipelineVersionEnum(args.sd_version), prompts=prompts, set_size=args.set_size, num_sets=args.num_sets, height=args.height, width=args.width, loras=loras, ti_embeddings=args.ti, seed=args.seed, enable_cpu_offload=args.enable_cpu_offload, ) if __name__ == "__main__": main() ================================================ FILE: src/invoke_training/scripts/invoke_train.py ================================================ import argparse from pathlib import Path import yaml from pydantic import TypeAdapter from invoke_training.config.pipeline_config import PipelineConfig from invoke_training.pipelines.invoke_train import train def parse_args(): parser = argparse.ArgumentParser(description="Run a training pipeline.") parser.add_argument( "-c", "--cfg-file", type=Path, required=True, help="Path to the YAML training config file.", ) return parser.parse_args() def main(): args = parse_args() # Load YAML config file. 
with open(args.cfg_file, "r") as f: cfg = yaml.safe_load(f) pipeline_adapter: TypeAdapter[PipelineConfig] = TypeAdapter(PipelineConfig) train_config = pipeline_adapter.validate_python(cfg) train(train_config) if __name__ == "__main__": main() ================================================ FILE: src/invoke_training/scripts/invoke_train_ui.py ================================================ import argparse import uvicorn from invoke_training.ui.app import build_app def main(): parser = argparse.ArgumentParser() parser.add_argument( "--host", default="127.0.0.1", help="The server host. Set `--host 0.0.0.0` to make the app available on your network.", ) parser.add_argument("--port", default=8000, type=int, help="The server port.") args = parser.parse_args() app = build_app() uvicorn.run( app, host=args.host, port=args.port, ) if __name__ == "__main__": main() ================================================ FILE: src/invoke_training/scripts/invoke_visualize_data_loading.py ================================================ import argparse import os import time from pathlib import Path import numpy as np import torch import yaml from PIL import Image from pydantic import TypeAdapter from torch.utils.data import DataLoader from invoke_training._shared.data.data_loaders.dreambooth_sd_dataloader import build_dreambooth_sd_dataloader from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import ( build_image_caption_sd_dataloader, ) from invoke_training._shared.data.data_loaders.image_pair_preference_sd_dataloader import ( build_image_pair_preference_sd_dataloader, ) from invoke_training._shared.data.data_loaders.textual_inversion_sd_dataloader import ( build_textual_inversion_sd_dataloader, ) from invoke_training.config.pipeline_config import PipelineConfig def save_image(torch_image: torch.Tensor, out_path: Path): """Save a torch image to disk. Args: torch_image (torch.Tensor): Shape=(C, H, W). 
Pixel values are expected to be normalized in the range [-1.0, 1.0]. out_path (Path): The output path. """ np_image = torch_image.clone().detach().cpu().numpy() # Convert back to range [0, 1.0]. np_image = np_image * 0.5 + 0.5 # Convert back to range [0, 255]. np_image *= 255 # Move channel axis from first dimension to last dimension. np_image = np.moveaxis(np_image, 0, -1) # Cast to np.uint8. np_image = np_image.astype(np.uint8) Image.fromarray(np_image).save(out_path) def parse_args(): parser = argparse.ArgumentParser(description="Visualize data loading from a pipeline config.") parser.add_argument( "-c", "--cfg-file", type=Path, required=True, help="Path to the YAML training config file.", ) return parser.parse_args() def visualize(data_loader: DataLoader): out_dir = Path(f"out_{str(time.time()).replace('.', '-')}/") os.makedirs(out_dir) for batch_idx, batch in enumerate(data_loader): print(f"Batch {batch_idx}:") batch_path = out_dir / f"batch_{batch_idx}" batch_path.mkdir() saved_images = [] for k, v in batch.items(): if isinstance(v, torch.Tensor): print(f"{k}: Tensor.shape={v.shape}") if len(v.shape) == 4 and v.shape[1] == 3: # This is likely a batch of RGB images, so we save them to disk. for i in range(v.shape[0]): out_path = batch_path / f"{k}_{i}.png" save_image(v[i, ...], out_path) saved_images.append(out_path) else: print(f"{k}: {v}") for saved_image in saved_images: print(f"Saved image to '{saved_image}'.") _ = input("\n\nPress Enter to continue to next batch...\n") def main(): args = parse_args() # Load YAML config file. 
with open(args.cfg_file, "r") as f: cfg = yaml.safe_load(f) pipeline_adapter: TypeAdapter[PipelineConfig] = TypeAdapter(PipelineConfig) train_config = pipeline_adapter.validate_python(cfg) data_loader_config = train_config.data_loader if data_loader_config.type == "IMAGE_CAPTION_SD_DATA_LOADER": data_loader = build_image_caption_sd_dataloader( config=data_loader_config, batch_size=train_config.train_batch_size, shuffle=False, ) elif data_loader_config.type == "TEXTUAL_INVERSION_SD_DATA_LOADER": data_loader = build_textual_inversion_sd_dataloader( config=data_loader_config, placeholder_token="", batch_size=train_config.train_batch_size, shuffle=False, ) elif data_loader_config.type == "DREAMBOOTH_SD_DATA_LOADER": data_loader = build_dreambooth_sd_dataloader( config=data_loader_config, batch_size=train_config.train_batch_size, shuffle=False, sequential_batching=False, ) elif data_loader_config.type == "IMAGE_PAIR_PREFERENCE_SD_DATA_LOADER": data_loader = build_image_pair_preference_sd_dataloader( config=data_loader_config, batch_size=train_config.train_batch_size, shuffle=False, ) else: raise ValueError(f"Unexpected data loader type: '{data_loader_config.type}'.") visualize(data_loader) if __name__ == "__main__": main() ================================================ FILE: src/invoke_training/scripts/utils/image_dir_dataset.py ================================================ import os import typing import torch from PIL import Image class ImageDirDataset(torch.utils.data.Dataset): """A simple dataset that loads images from a directory.""" def __init__( self, dataset_dir: str, image_extensions: typing.Optional[list[str]] = None, ): super().__init__() if image_extensions is None: image_extensions = [".png", ".jpg", ".jpeg"] image_extensions = [ext.lower() for ext in image_extensions] # Determine the list of image paths to include in the dataset. 
self._image_paths: list[str] = [] for image_file in os.listdir(dataset_dir): image_path = os.path.join(dataset_dir, image_file) if os.path.isfile(image_path) and os.path.splitext(image_path)[1].lower() in image_extensions: self._image_paths.append(image_path) self._image_paths.sort() def _load_image(self, image_path: str) -> Image.Image: # We call `convert("RGB")` to drop the alpha channel from RGBA images, or to repeat channels for greyscale # images. return Image.open(image_path).convert("RGB") def __len__(self) -> int: return len(self._image_paths) def __getitem__(self, idx: int): image_path = self._image_paths[idx] image = self._load_image(image_path) return {"image_path": self._image_paths[idx], "image": image} def list_collate_fn(examples): """Custom collate_fn that combines images into a list rather than stacking into a tensor. This is what the Moondream model expects. """ return { "image": [example["image"] for example in examples], "image_path": [example["image_path"] for example in examples], } ================================================ FILE: src/invoke_training/ui/__init__.py ================================================ ================================================ FILE: src/invoke_training/ui/app.py ================================================ from pathlib import Path import gradio as gr from fastapi import FastAPI from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles from invoke_training.ui.pages.data_page import DataPage from invoke_training.ui.pages.training_page import TrainingPage def build_app(): training_page = TrainingPage() data_page = DataPage() app = FastAPI() @app.get("/") async def root(): index_path = Path(__file__).parent / "index.html" return FileResponse(index_path) app.mount("/assets", StaticFiles(directory=Path(__file__).parent.parent / "assets"), name="assets") app = gr.mount_gradio_app(app, training_page.app(), "/train", app_kwargs={"favicon_path": "/assets/favicon.png"}) app = 
gr.mount_gradio_app(app, data_page.app(), "/data", app_kwargs={"favicon_path": "/assets/favicon.png"}) return app ================================================ FILE: src/invoke_training/ui/config_groups/__init__.py ================================================ ================================================ FILE: src/invoke_training/ui/config_groups/aspect_ratio_bucket_config_group.py ================================================ from typing import Any import gradio as gr from invoke_training.config.data.data_loader_config import AspectRatioBucketConfig from invoke_training.ui.config_groups.ui_config_element import UIConfigElement class AspectRatioBucketConfigGroup(UIConfigElement): def __init__(self): gr.Markdown( "Aspect ratio bucket resolutions are generated as follows:\n" "- Iterate over 'first' dimension values from `start_dim` to `end_dim` in steps of size `divisible_by`.\n" "- Calculate the 'second' dimension to be as close as possible to the total number of pixels in " "`target_resolution`, while still being divisible by `divisible_by`." ) self.enabled = gr.Checkbox(label="Use Aspect Ratio Bucketing", interactive=True) self.target_resolution = gr.Number(label="target_resolution", interactive=True, precision=0) self.start_dim = gr.Number(label="start_dimension", interactive=True, precision=0) self.end_dim = gr.Number(label="end_dimension", interactive=True, precision=0) self.divisible_by = gr.Number(label="divisible_by", interactive=True, precision=0) def update_ui_components_with_config_data( self, config: AspectRatioBucketConfig | None ) -> dict[gr.components.Component, Any]: enabled = True if config is None: enabled = False # We just construct this config to hold default values. 
config = AspectRatioBucketConfig(target_resolution=512, start_dim=256, end_dim=768, divisible_by=64) update_dict = { self.enabled: enabled, self.target_resolution: config.target_resolution, self.start_dim: config.start_dim, self.end_dim: config.end_dim, self.divisible_by: config.divisible_by, } return update_dict def update_config_with_ui_component_data( self, orig_config: AspectRatioBucketConfig | None, ui_data: dict[gr.components.Component, Any] ) -> AspectRatioBucketConfig | None: # TODO: Use orig_config? if not ui_data.pop(self.enabled): # Pop fields from ui_data so that upstream code knows that the fields were handled. ui_data.pop(self.target_resolution) ui_data.pop(self.start_dim) ui_data.pop(self.end_dim) ui_data.pop(self.divisible_by) return None new_config = AspectRatioBucketConfig( target_resolution=ui_data.pop(self.target_resolution), start_dim=ui_data.pop(self.start_dim), end_dim=ui_data.pop(self.end_dim), divisible_by=ui_data.pop(self.divisible_by), ) return new_config ================================================ FILE: src/invoke_training/ui/config_groups/base_pipeline_config_group.py ================================================ from typing import Any import gradio as gr from invoke_training.config.base_pipeline_config import BasePipelineConfig from invoke_training.config.pipeline_config import PipelineConfig from invoke_training.ui.config_groups.ui_config_element import UIConfigElement class BasePipelineConfigGroup(UIConfigElement): def __init__(self): self.base_output_dir = gr.Textbox( label="Base Output Directory", info="The base output directory where the training outputs (model checkpoints, logs," " intermediate predictions) will be written.", interactive=True, ) with gr.Row(): with gr.Column(): self.max_train_steps_or_epochs_dropdown = gr.Dropdown( label="Training Length", info="Train for a fixed number of gradient update steps or epochs.", choices=["max_train_steps", "max_train_epochs"], interactive=True, ) self.max_train_steps_or_epochs 
= gr.Number(label="Steps or Epochs", precision=0, interactive=True) with gr.Column(): self.save_every_n_steps_or_epochs_dropdown = gr.Dropdown( label="Checkpoint Save Frequency", info="Save a checkpoint every N gradient update steps or epochs.", choices=["save_every_n_steps", "save_every_n_epochs"], interactive=True, ) self.save_every_n_steps_or_epochs = gr.Number(label="Steps or Epochs", precision=0, interactive=True) with gr.Column(): self.validate_every_n_steps_or_epochs_dropdown = gr.Dropdown( label="Validation Frequency", info="Save validation images every N gradient update steps or epochs.", choices=["validate_every_n_steps", "validate_every_n_epochs"], interactive=True, ) self.validate_every_n_steps_or_epochs = gr.Number( label="Steps or Epochs", precision=0, interactive=True ) self.seed = gr.Number( label="Seed", info="Set to any constant integer for consistent training results. If set to null, training" " will be non-deterministic.", precision=0, interactive=True, ) def update_ui_components_with_config_data(self, config: BasePipelineConfig) -> dict[gr.components.Component, Any]: if config.max_train_epochs is not None: max_train_steps_or_epochs_dropdown = "max_train_epochs" max_train_steps_or_epochs = config.max_train_epochs elif config.max_train_steps is not None: max_train_steps_or_epochs_dropdown = "max_train_steps" max_train_steps_or_epochs = config.max_train_steps else: raise ValueError("One of max_train_epochs or max_train_steps must be set.") if config.save_every_n_epochs is not None: save_every_n_steps_or_epochs_dropdown = "save_every_n_epochs" save_every_n_steps_or_epochs = config.save_every_n_epochs elif config.save_every_n_steps is not None: save_every_n_steps_or_epochs_dropdown = "save_every_n_steps" save_every_n_steps_or_epochs = config.save_every_n_steps else: raise ValueError("One of save_every_n_epochs or save_every_n_steps must be set.") if config.validate_every_n_epochs is not None: validate_every_n_steps_or_epochs_dropdown = 
"validate_every_n_epochs" validate_every_n_steps_or_epochs = config.validate_every_n_epochs elif config.validate_every_n_steps is not None: validate_every_n_steps_or_epochs_dropdown = "validate_every_n_steps" validate_every_n_steps_or_epochs = config.validate_every_n_steps else: raise ValueError("One of validate_every_n_epochs or validate_every_n_steps must be set.") return { self.seed: config.seed, self.base_output_dir: config.base_output_dir, self.max_train_steps_or_epochs_dropdown: max_train_steps_or_epochs_dropdown, self.max_train_steps_or_epochs: max_train_steps_or_epochs, self.save_every_n_steps_or_epochs_dropdown: save_every_n_steps_or_epochs_dropdown, self.save_every_n_steps_or_epochs: save_every_n_steps_or_epochs, self.validate_every_n_steps_or_epochs_dropdown: validate_every_n_steps_or_epochs_dropdown, self.validate_every_n_steps_or_epochs: validate_every_n_steps_or_epochs, } def update_config_with_ui_component_data( self, orig_config: PipelineConfig, ui_data: dict[gr.components.Component, Any] ) -> PipelineConfig: new_config = orig_config.model_copy(deep=True) new_config.seed = ui_data.pop(self.seed) new_config.base_output_dir = ui_data.pop(self.base_output_dir) if ui_data.pop(self.max_train_steps_or_epochs_dropdown) == "max_train_epochs": new_config.max_train_epochs = ui_data.pop(self.max_train_steps_or_epochs) new_config.max_train_steps = None else: new_config.max_train_steps = ui_data.pop(self.max_train_steps_or_epochs) new_config.max_train_epochs = None if ui_data.pop(self.save_every_n_steps_or_epochs_dropdown) == "save_every_n_epochs": new_config.save_every_n_epochs = ui_data.pop(self.save_every_n_steps_or_epochs) new_config.save_every_n_steps = None else: new_config.save_every_n_steps = ui_data.pop(self.save_every_n_steps_or_epochs) new_config.save_every_n_epochs = None if ui_data.pop(self.validate_every_n_steps_or_epochs_dropdown) == "validate_every_n_epochs": new_config.validate_every_n_epochs = ui_data.pop(self.validate_every_n_steps_or_epochs) 
new_config.validate_every_n_steps = None else: new_config.validate_every_n_steps = ui_data.pop(self.validate_every_n_steps_or_epochs) new_config.validate_every_n_epochs = None return new_config ================================================ FILE: src/invoke_training/ui/config_groups/dataset_config_group.py ================================================ from typing import Any import gradio as gr from invoke_training.config.data.dataset_config import ( HFHubImageCaptionDatasetConfig, ImageCaptionDatasetConfig, ImageCaptionDirDatasetConfig, ImageCaptionJsonlDatasetConfig, ImageDirDatasetConfig, ) from invoke_training.ui.config_groups.ui_config_element import UIConfigElement ALL_DATASET_TYPES = [ "HF_HUB_IMAGE_CAPTION_DATASET", "IMAGE_CAPTION_JSONL_DATASET", "IMAGE_CAPTION_DIR_DATASET", "IMAGE_DIR_DATASET", ] class HFHubImageCaptionDatasetConfigGroup(UIConfigElement): def __init__(self): self.dataset_name = gr.Textbox( label="Dataset Name", info="Hugging Face Dataset Name (e.g., owner/RepoID).", interactive=True ) with gr.Row(): self.dataset_config_name = gr.Textbox( label="Dataset Config Name (Optional)", info="The Hugging Face dataset config name. Leave as None if there's only one config.", interactive=True, ) with gr.Row(): self.hf_cache_dir = gr.Textbox( label="Cache Directory", info="The Hugging Face cache directory to use for dataset downloads. 
If None, the default value" " will be used (usually '~/.cache/huggingface/datasets').", interactive=True, ) # self.image_column = gr.Textbox(label="image_column", interactive=True) # self.caption_column = gr.Textbox(label="caption_column", interactive=True) def update_ui_components_with_config_data( self, config: HFHubImageCaptionDatasetConfig | None ) -> dict[gr.components.Component, Any]: return { self.dataset_name: config.dataset_name if config else "", self.dataset_config_name: config.dataset_config_name if config else None, self.hf_cache_dir: config.hf_cache_dir if config else None, # self.image_column: config.image_column, # self.caption_column: config.caption_column, } def update_config_with_ui_component_data( self, orig_config: HFHubImageCaptionDatasetConfig | None, ui_data: dict[gr.components.Component, Any] ) -> HFHubImageCaptionDatasetConfig: assert orig_config is None # new_config = orig_config.model_copy(deep=True) new_config = HFHubImageCaptionDatasetConfig( dataset_name=ui_data.pop(self.dataset_name), dataset_config_name=ui_data.pop(self.dataset_config_name) or None, hf_cache_dir=ui_data.pop(self.hf_cache_dir) or None, # image_column=ui_data.pop(self.image_column), # caption_column=ui_data.pop(self.caption_column), ) return new_config class ImageCaptionJsonlDatasetConfigGroup(UIConfigElement): def __init__(self): self.jsonl_path = gr.Textbox(label="jsonl_path", info="Path to the dataset `.jsonl` file.", interactive=True) self.image_column = gr.Textbox( label="image_column", info="The name of the field in the `.jsonl` containing image file paths.", interactive=True, ) self.caption_column = gr.Textbox( label="caption_column", info="The name of the field in the `.jsonl` containing image captions.", interactive=True, ) self.keep_in_memory = gr.Checkbox( label="keep_in_memory", info="If True, the entire dataset will be kept in RAM. 
This increases speed for small datasets at the " "cost of higher RAM usage.", interactive=True, ) def update_ui_components_with_config_data( self, config: ImageCaptionJsonlDatasetConfig | None ) -> dict[gr.components.Component, Any]: if config is None: # We just construct this so that we can use its default values. config = ImageCaptionJsonlDatasetConfig(jsonl_path="") return { self.jsonl_path: config.jsonl_path, self.image_column: config.image_column, self.caption_column: config.caption_column, self.keep_in_memory: config.keep_in_memory, } def update_config_with_ui_component_data( self, orig_config: ImageCaptionJsonlDatasetConfig | None, ui_data: dict[gr.components.Component, Any] ) -> ImageCaptionJsonlDatasetConfig: assert orig_config is None # new_config = orig_config.model_copy(deep=True) new_config = ImageCaptionJsonlDatasetConfig( jsonl_path=ui_data.pop(self.jsonl_path), image_column=ui_data.pop(self.image_column), caption_column=ui_data.pop(self.caption_column), keep_in_memory=ui_data.pop(self.keep_in_memory), ) return new_config class ImageCaptionDirDatasetConfigGroup(UIConfigElement): def __init__(self): with gr.Row(): self.dataset_dir = gr.Textbox( label="dataset_dir", info="The path to the dataset directory.", interactive=True ) with gr.Row(): self.keep_in_memory = gr.Checkbox( label="keep_in_memory", info="If True, the entire dataset will be kept in RAM. 
This increases speed for small datasets at the " "cost of higher RAM usage.", interactive=True, ) def update_ui_components_with_config_data( self, config: ImageCaptionDirDatasetConfig | None ) -> dict[gr.components.Component, Any]: return { self.dataset_dir: config.dataset_dir if config else "", self.keep_in_memory: config.keep_in_memory if config else False, } def update_config_with_ui_component_data( self, orig_config: ImageCaptionDirDatasetConfig | None, ui_data: dict[gr.components.Component, Any] ) -> ImageCaptionDirDatasetConfig: assert orig_config is None # new_config = orig_config.model_copy(deep=True) new_config = ImageCaptionDirDatasetConfig( dataset_dir=ui_data.pop(self.dataset_dir), keep_in_memory=ui_data.pop(self.keep_in_memory) ) return new_config class ImageDirDatasetConfigGroup(UIConfigElement): def __init__(self): with gr.Row(): self.dataset_dir = gr.Textbox( label="dataset_dir", info="The path to the dataset directory.", interactive=True ) with gr.Row(): self.keep_in_memory = gr.Checkbox( label="keep_in_memory", info="If True, the entire dataset will be kept in RAM. 
This increases speed for small datasets at the " "cost of higher RAM usage.", interactive=True, ) def update_ui_components_with_config_data( self, config: ImageDirDatasetConfig | None ) -> dict[gr.components.Component, Any]: return { self.dataset_dir: config.dataset_dir if config else "", self.keep_in_memory: config.keep_in_memory if config else False, } def update_config_with_ui_component_data( self, orig_config: ImageDirDatasetConfig | None, ui_data: dict[gr.components.Component, Any] ) -> ImageDirDatasetConfig: assert orig_config is None # new_config = orig_config.model_copy(deep=True) new_config = ImageDirDatasetConfig( dataset_dir=ui_data.pop(self.dataset_dir), keep_in_memory=ui_data.pop(self.keep_in_memory) ) return new_config class DatasetConfigGroup(UIConfigElement): def __init__(self, allowed_types: list[str]): self.type = gr.Dropdown( choices=[t for t in ALL_DATASET_TYPES if t in allowed_types], label="Dataset Type", info="The type of dataset to use for training. See " "https://invoke-ai.github.io/invoke-training/concepts/dataset_formats/ for a description of each format.", interactive=True, ) with gr.Group() as hf_hub_image_caption_dataset_config_group: self.hf_hub_image_caption_dataset_config = HFHubImageCaptionDatasetConfigGroup() self.hf_hub_image_caption_dataset_config_group = hf_hub_image_caption_dataset_config_group with gr.Group() as image_caption_jsonl_dataset_config_group: self.image_caption_jsonl_dataset_config = ImageCaptionJsonlDatasetConfigGroup() self.image_caption_jsonl_dataset_config_group = image_caption_jsonl_dataset_config_group with gr.Group() as image_caption_dir_dataset_config_group: self.image_caption_dir_dataset_config = ImageCaptionDirDatasetConfigGroup() self.image_caption_dir_dataset_config_group = image_caption_dir_dataset_config_group with gr.Group() as image_dir_dataset_config_group: self.image_dir_dataset_config = ImageDirDatasetConfigGroup() self.image_dir_dataset_config_group = image_dir_dataset_config_group 
self.type.change( self._on_type_change, inputs=[self.type], outputs=[ self.hf_hub_image_caption_dataset_config_group, self.image_caption_jsonl_dataset_config_group, self.image_caption_dir_dataset_config_group, self.image_dir_dataset_config_group, ], ) def _on_type_change(self, type: str): return { self.hf_hub_image_caption_dataset_config_group: gr.Group(visible=type == "HF_HUB_IMAGE_CAPTION_DATASET"), self.image_caption_jsonl_dataset_config_group: gr.Group(visible=type == "IMAGE_CAPTION_JSONL_DATASET"), self.image_caption_dir_dataset_config_group: gr.Group(visible=type == "IMAGE_CAPTION_DIR_DATASET"), self.image_dir_dataset_config_group: gr.Group(visible=type == "IMAGE_DIR_DATASET"), } def update_ui_components_with_config_data( self, config: ImageCaptionDatasetConfig ) -> dict[gr.components.Component, Any]: update_dict = { self.type: config.type, self.hf_hub_image_caption_dataset_config_group: gr.Group( visible=config.type == "HF_HUB_IMAGE_CAPTION_DATASET" ), self.image_caption_jsonl_dataset_config_group: gr.Group( visible=config.type == "IMAGE_CAPTION_JSONL_DATASET" ), self.image_caption_dir_dataset_config_group: gr.Group(visible=config.type == "IMAGE_CAPTION_DIR_DATASET"), self.image_dir_dataset_config_group: gr.Group(visible=config.type == "IMAGE_DIR_DATASET"), } update_dict.update( self.hf_hub_image_caption_dataset_config.update_ui_components_with_config_data( config if config.type == "HF_HUB_IMAGE_CAPTION_DATASET" else None ) ) update_dict.update( self.image_caption_jsonl_dataset_config.update_ui_components_with_config_data( config if config.type == "IMAGE_CAPTION_JSONL_DATASET" else None ) ) update_dict.update( self.image_caption_dir_dataset_config.update_ui_components_with_config_data( config if config.type == "IMAGE_CAPTION_DIR_DATASET" else None ) ) update_dict.update( self.image_dir_dataset_config.update_ui_components_with_config_data( config if config.type == "IMAGE_DIR_DATASET" else None ) ) return update_dict def update_config_with_ui_component_data( 
self, orig_config: ImageCaptionDatasetConfig, ui_data: dict[gr.components.Component, Any] ) -> ImageCaptionDatasetConfig: # TODO: Use orig_config. new_config_hf_hub = self.hf_hub_image_caption_dataset_config.update_config_with_ui_component_data(None, ui_data) new_config_jsonl = self.image_caption_jsonl_dataset_config.update_config_with_ui_component_data(None, ui_data) new_config_image_caption_dir = self.image_caption_dir_dataset_config.update_config_with_ui_component_data( None, ui_data ) new_config_image_dir = self.image_dir_dataset_config.update_config_with_ui_component_data(None, ui_data) type = ui_data.pop(self.type) if type == "HF_HUB_IMAGE_CAPTION_DATASET": new_config = new_config_hf_hub elif type == "IMAGE_CAPTION_JSONL_DATASET": new_config = new_config_jsonl elif type == "IMAGE_CAPTION_DIR_DATASET": new_config = new_config_image_caption_dir elif type == "IMAGE_DIR_DATASET": new_config = new_config_image_dir else: raise ValueError(f"Unknown dataset type: {type}") return new_config ================================================ FILE: src/invoke_training/ui/config_groups/flux_lora_config_group.py ================================================ import typing import gradio as gr from invoke_training.pipelines.flux.lora.config import FluxLoraConfig from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup from invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import ( ImageCaptionSDDataLoaderConfigGroup, ) from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup from invoke_training.ui.config_groups.ui_config_element import UIConfigElement from invoke_training.ui.utils.utils import get_typing_literal_options class FluxLoraConfigGroup(UIConfigElement): def __init__(self): """The Flux LoRA configs.""" gr.Markdown("## Basic Configs") with gr.Row(): with gr.Column(scale=1): with gr.Tab("Base Model"): self.model = gr.Textbox( label="Model", info="The base model. 
Can be a Hugging Face Hub model name, or a path to a local model (in " "diffusers or checkpoint format).", type="text", interactive=True, ) # Flux model doesn't use hf_variant with gr.Column(scale=3): with gr.Tab("Training Outputs"): self.base_pipeline_config_group = BasePipelineConfigGroup() self.max_checkpoints = gr.Number( label="Maximum Number of Checkpoints", info="The maximum number of checkpoints to keep on disk from this training run. Earlier " "checkpoints will be deleted to respect this limit.", interactive=True, precision=0, ) gr.Markdown("## Data Configs") self.image_caption_sd_data_loader_config_group = ImageCaptionSDDataLoaderConfigGroup() gr.Markdown("## Optimizer Configs") self.optimizer_config_group = OptimizerConfigGroup() gr.Markdown("## Scheduler Configs") with gr.Row(): with gr.Column(): self.lr_scheduler = gr.Dropdown( label="Learning Rate Scheduler", choices=get_typing_literal_options(FluxLoraConfig, "lr_scheduler"), interactive=True, ) self.lr_warmup_steps = gr.Number( label="Learning Rate Warmup Steps", info="Number of steps for the warmup in the lr scheduler.", interactive=True, precision=0, ) gr.Markdown("## General Training Configs") with gr.Tab("Core"): with gr.Row(): self.train_transformer = gr.Checkbox(label="Train Transformer", interactive=True) with gr.Row(): self.transformer_learning_rate = gr.Number( label="Transformer Learning Rate", info="The transformer learning rate. 
Set to 0 or leave empty to inherit from the base optimizer " "learning rate.", interactive=True, ) with gr.Row(): self.gradient_accumulation_steps = gr.Number( label="Gradient Accumulation Steps", info="Number of updates steps to accumulate before performing a backward/update pass.", interactive=True, precision=0, ) self.gradient_checkpointing = gr.Checkbox( label="Gradient Checkpointing", info="Whether to use gradient checkpointing to save memory at the expense of slower backward pass.", interactive=True, ) # Training/saving/validating steps/epochs are handled by BasePipelineConfigGroup with gr.Tab("Advanced"): with gr.Column(): self.lora_rank_dim = gr.Number( label="LoRA Rank Dim", info="The rank dimension to use for the LoRA layers. Increasing the rank dimension increases" " the model's expressivity, but also increases the size of the generated LoRA model.", interactive=True, precision=0, ) self.min_snr_gamma = gr.Number( label="Minimum SNR Gamma", info="min_snr_gamma acts like an an upper bound on the weight of samples with low noise " "levels. If None, then Min-SNR weighting will not be applied. If enabled, the recommended " "value is min_snr gamma = 5.0.", interactive=True, ) self.max_grad_norm = gr.Number( label="Max Gradient Norm", info="Max gradient norm for clipping. 
Set to 0 or leave empty for no clipping (null).", interactive=True, ) self.train_batch_size = gr.Number( label="Batch Size", info="The Training Batch Size - Higher values require increasing amounts of VRAM.", precision=0, interactive=True, ) self.weight_dtype = gr.Dropdown( label="Weight Data Type", choices=get_typing_literal_options(FluxLoraConfig, "weight_dtype"), info="The data type to use for model weights during training.", interactive=True, ) self.mixed_precision = gr.Dropdown( label="Mixed Precision", choices=get_typing_literal_options(FluxLoraConfig, "mixed_precision"), info="The mixed precision mode to use.", interactive=True, ) self.lora_checkpoint_format = gr.Dropdown( label="LoRA Checkpoint Format", choices=get_typing_literal_options(FluxLoraConfig, "lora_checkpoint_format"), info="The format of the LoRA checkpoint to save.", interactive=True, ) self.timestep_sampler = gr.Dropdown( label="Timestep Sampler", choices=get_typing_literal_options(FluxLoraConfig, "timestep_sampler"), info="The timestep sampler to use.", interactive=True, ) self.discrete_flow_shift = gr.Number( label="Discrete Flow Shift", info="The shift parameter for the discrete flow. Only used if timestep_sampler is 'shift'.", interactive=True, ) self.sigmoid_scale = gr.Number( label="Sigmoid Scale", info="The scale parameter for the sigmoid function. Only used if timestep_sampler is 'shift'.", interactive=True, ) self.lora_scale = gr.Number( label="LoRA Scale", info="The scale parameter for the LoRA layers.", interactive=True, ) self.guidance_scale = gr.Number( label="Guidance Scale", info="The guidance scale for the Flux model.", interactive=True, ) self.use_masks = gr.Checkbox( label="Use Masks", info="If True, image masks will be applied to weight the loss during training. 
The dataset must " "contain masks for this feature to be used.", interactive=True, ) self.prediction_type = gr.Dropdown( label="Prediction Type", choices=["epsilon", "v_prediction", None], info="The prediction type that will be used for training.", interactive=True, ) gr.Markdown("## Validation") with gr.Group(): self.validation_prompts = gr.Textbox( label="Validation Prompts", info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' " "delimiter. For example: `positive prompt[NEG]negative prompt`. ", lines=5, interactive=True, ) self.num_validation_images_per_prompt = gr.Number( label="# of Validation Images to Generate per Prompt", precision=0, interactive=True ) def get_ui_output_components(self) -> list[gr.components.Component]: # Get our own components components = [ self.model, self.train_transformer, self.transformer_learning_rate, self.gradient_accumulation_steps, self.gradient_checkpointing, self.lr_scheduler, self.lr_warmup_steps, self.lora_rank_dim, self.min_snr_gamma, self.max_grad_norm, self.train_batch_size, self.weight_dtype, self.mixed_precision, self.lora_checkpoint_format, self.timestep_sampler, self.discrete_flow_shift, self.sigmoid_scale, self.lora_scale, self.guidance_scale, self.use_masks, self.prediction_type, # These are not UI components but need to be preserved # self.flux_lora_target_modules, # self.text_encoder_lora_target_modules, self.validation_prompts, self.num_validation_images_per_prompt, self.max_checkpoints, ] # Add components from nested config groups components.extend(self.base_pipeline_config_group.get_ui_output_components()) components.extend(self.image_caption_sd_data_loader_config_group.get_ui_output_components()) components.extend(self.optimizer_config_group.get_ui_output_components()) return components def update_ui_components_with_config_data( self, config: FluxLoraConfig ) -> dict[gr.components.Component, typing.Any]: try: update_dict = { self.model: config.model, 
self.train_transformer: config.train_transformer, self.transformer_learning_rate: config.transformer_learning_rate, self.gradient_accumulation_steps: config.gradient_accumulation_steps, self.gradient_checkpointing: config.gradient_checkpointing, self.lr_scheduler: config.lr_scheduler, self.lr_warmup_steps: config.lr_warmup_steps, self.lora_rank_dim: config.lora_rank_dim, self.min_snr_gamma: config.min_snr_gamma, self.max_grad_norm: config.max_grad_norm, self.train_batch_size: config.train_batch_size, self.weight_dtype: config.weight_dtype, self.mixed_precision: config.mixed_precision, self.lora_checkpoint_format: config.lora_checkpoint_format, self.timestep_sampler: config.timestep_sampler, self.discrete_flow_shift: config.discrete_flow_shift, self.sigmoid_scale: config.sigmoid_scale, self.lora_scale: config.lora_scale, self.guidance_scale: config.guidance_scale, self.use_masks: config.use_masks, self.prediction_type: config.prediction_type, self.validation_prompts: config.validation_prompts, self.num_validation_images_per_prompt: config.num_validation_images_per_prompt, self.max_checkpoints: config.max_checkpoints, } # Update with nested config groups try: update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config)) except Exception as e: print(f"Error updating base pipeline config: {e}") try: update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer)) except Exception as e: print(f"Error updating optimizer config: {e}") try: update_dict.update( self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data( config.data_loader ) ) except Exception as e: print(f"Error updating data loader config: {e}") # Sanity check to catch if we accidentally forget to update a UI component. 
# We'll skip this check for now as it's causing issues with nested components # assert set(update_dict.keys()) == set(self.get_ui_output_components()) return update_dict except Exception as e: print(f"Error in update_ui_components_with_config_data: {e}") # Return a minimal update dict to avoid UI errors return {self.model: config.model} def update_config_with_ui_component_data( # noqa: C901 self, orig_config: FluxLoraConfig, ui_data: dict[gr.components.Component, typing.Any] ) -> FluxLoraConfig: try: # Handle the case where orig_config might be None if orig_config is None: from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig from invoke_training.pipelines.flux.lora.config import FluxLoraConfig # Create a default config orig_config = FluxLoraConfig( model="black-forest-labs/FLUX.1-dev", optimizer=AdamOptimizerConfig(), ) new_config = orig_config.model_copy(deep=True) # Create a copy of ui_data to avoid modifying the original ui_data_copy = ui_data.copy() # Helper function to safely pop values from ui_data def safe_pop(component, default=None): try: return ui_data_copy.pop(component) except (KeyError, TypeError) as e: print(f"Error popping {component}: {e}") return default # Set basic properties new_config.model = safe_pop(self.model, new_config.model) new_config.train_transformer = safe_pop(self.train_transformer, new_config.train_transformer) # Note: train_text_encoder and text_encoder_learning_rate are not supported for Flux LoRA transformer_lr_value = safe_pop(self.transformer_learning_rate, new_config.transformer_learning_rate) new_config.transformer_learning_rate = None if transformer_lr_value == 0 else transformer_lr_value new_config.gradient_accumulation_steps = safe_pop( self.gradient_accumulation_steps, new_config.gradient_accumulation_steps ) new_config.gradient_checkpointing = safe_pop(self.gradient_checkpointing, new_config.gradient_checkpointing) # Training/saving/validating steps/epochs are handled by 
BasePipelineConfigGroup new_config.lr_scheduler = safe_pop(self.lr_scheduler, new_config.lr_scheduler) new_config.lr_warmup_steps = safe_pop(self.lr_warmup_steps, new_config.lr_warmup_steps) new_config.lora_rank_dim = safe_pop(self.lora_rank_dim, new_config.lora_rank_dim) new_config.min_snr_gamma = safe_pop(self.min_snr_gamma, new_config.min_snr_gamma) max_grad_norm_value = safe_pop(self.max_grad_norm, new_config.max_grad_norm) new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value new_config.train_batch_size = safe_pop(self.train_batch_size, new_config.train_batch_size) new_config.weight_dtype = safe_pop(self.weight_dtype, new_config.weight_dtype) new_config.mixed_precision = safe_pop(self.mixed_precision, new_config.mixed_precision) new_config.lora_checkpoint_format = safe_pop(self.lora_checkpoint_format, new_config.lora_checkpoint_format) new_config.timestep_sampler = safe_pop(self.timestep_sampler, new_config.timestep_sampler) new_config.discrete_flow_shift = safe_pop(self.discrete_flow_shift, new_config.discrete_flow_shift) new_config.sigmoid_scale = safe_pop(self.sigmoid_scale, new_config.sigmoid_scale) new_config.lora_scale = safe_pop(self.lora_scale, new_config.lora_scale) new_config.guidance_scale = safe_pop(self.guidance_scale, new_config.guidance_scale) new_config.use_masks = safe_pop(self.use_masks, new_config.use_masks) new_config.prediction_type = safe_pop(self.prediction_type, new_config.prediction_type) new_config.max_checkpoints = safe_pop(self.max_checkpoints, new_config.max_checkpoints) # Preserve the target modules from the original config # These are not UI components but need to be preserved if hasattr(orig_config, "flux_lora_target_modules") and orig_config.flux_lora_target_modules: new_config.flux_lora_target_modules = orig_config.flux_lora_target_modules if ( hasattr(orig_config, "text_encoder_lora_target_modules") and orig_config.text_encoder_lora_target_modules ): new_config.text_encoder_lora_target_modules = 
orig_config.text_encoder_lora_target_modules # Handle validation prompts try: validation_prompts_text = safe_pop(self.validation_prompts, "") positive_prompts = validation_prompts_text new_config.validation_prompts = positive_prompts except Exception as e: print(f"Error processing validation prompts: {e}") new_config.num_validation_images_per_prompt = safe_pop( self.num_validation_images_per_prompt, new_config.num_validation_images_per_prompt ) # Update nested configs try: data_loader_config_group = self.image_caption_sd_data_loader_config_group # Handle the case where data_loader might be None new_config.data_loader = data_loader_config_group.update_config_with_ui_component_data( new_config.data_loader, ui_data_copy ) except Exception as e: print(f"Error updating data loader config: {e}") try: base_pipeline_group = self.base_pipeline_config_group new_config = base_pipeline_group.update_config_with_ui_component_data(new_config, ui_data_copy) except Exception as e: print(f"Error updating base pipeline config: {e}") try: # Handle the case where optimizer might be None if new_config.optimizer is None: from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig new_config.optimizer = AdamOptimizerConfig() new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data( new_config.optimizer, ui_data_copy ) except Exception as e: print(f"Error updating optimizer config: {e}") # We're more lenient with the assertion now if len(ui_data_copy) > 0: print(f"Warning: {len(ui_data_copy)} UI components were not transferred to the config") return new_config except Exception as e: print(f"Error in update_config_with_ui_component_data: {e}") # Return the original config to avoid errors return orig_config ================================================ FILE: src/invoke_training/ui/config_groups/image_caption_sd_data_loader_config_group.py ================================================ from typing import Any import gradio as gr from 
invoke_training.config.data.data_loader_config import ImageCaptionSDDataLoaderConfig from invoke_training.ui.config_groups.aspect_ratio_bucket_config_group import AspectRatioBucketConfigGroup from invoke_training.ui.config_groups.dataset_config_group import DatasetConfigGroup from invoke_training.ui.config_groups.ui_config_element import UIConfigElement class ImageCaptionSDDataLoaderConfigGroup(UIConfigElement): def __init__(self): with gr.Tab("Data Source Configs"): with gr.Row(): with gr.Column(scale=1): with gr.Group(): self.dataset = DatasetConfigGroup( allowed_types=[ "HF_HUB_IMAGE_CAPTION_DATASET", "IMAGE_CAPTION_JSONL_DATASET", "IMAGE_CAPTION_DIR_DATASET", ] ) with gr.Column(scale=3): with gr.Tab("Data Loading Configs"): with gr.Group(): with gr.Row(): self.resolution = gr.Number( label="Resolution", info="The resolution for input images. All of the images in the dataset will be" " resized to this resolution unless the aspect_ratio_buckets config is set.", precision=0, interactive=True, ) self.dataloader_num_workers = gr.Number( label="Dataloading Workers", info="Number of subprocesses to use for data loading. 0 means that the data will" " be loaded in the main process.", precision=0, interactive=True, ) with gr.Row(): self.center_crop = gr.Checkbox( label="Center Crop", info="If set, input images will be center-cropped to the target resolution." " Otherwise, input images will be randomly cropped to the target resolution.", interactive=True, ) self.random_flip = gr.Checkbox( label="Random Flip", info="If set, random flip augmentations will be applied to input images.", interactive=True, ) self.caption_prefix = gr.Textbox( label="Caption Prefix", info="A prefix that will be prepended to all captions." 
" If None, no prefix will be added.", interactive=True, ) with gr.Tab("Aspect Ratio Bucketing Configs"): self.aspect_ratio_bucket_config_group = AspectRatioBucketConfigGroup() def update_ui_components_with_config_data( self, config: ImageCaptionSDDataLoaderConfig ) -> dict[gr.components.Component, Any]: update_dict = { self.resolution: config.resolution, self.center_crop: config.center_crop, self.random_flip: config.random_flip, self.caption_prefix: config.caption_prefix, self.dataloader_num_workers: config.dataloader_num_workers, } update_dict.update(self.dataset.update_ui_components_with_config_data(config.dataset)) update_dict.update( self.aspect_ratio_bucket_config_group.update_ui_components_with_config_data(config.aspect_ratio_buckets) ) return update_dict def update_config_with_ui_component_data( self, orig_config: ImageCaptionSDDataLoaderConfig, ui_data: dict[gr.components.Component, Any], ) -> ImageCaptionSDDataLoaderConfig: # Handle the case where orig_config is None if orig_config is None: from invoke_training.config.data.data_loader_config import ( AspectRatioBucketConfig, ImageCaptionSDDataLoaderConfig, ) from invoke_training.config.data.dataset_config import ImageCaptionJsonlDatasetConfig # Create a default config orig_config = ImageCaptionSDDataLoaderConfig( type="IMAGE_CAPTION_SD_DATA_LOADER", dataset=ImageCaptionJsonlDatasetConfig(jsonl_path=""), aspect_ratio_buckets=AspectRatioBucketConfig(), resolution=512, center_crop=False, random_flip=True, caption_prefix=None, dataloader_num_workers=4, ) new_config = orig_config.model_copy(deep=True) new_config.dataset = self.dataset.update_config_with_ui_component_data(orig_config.dataset, ui_data) new_config.aspect_ratio_buckets = self.aspect_ratio_bucket_config_group.update_config_with_ui_component_data( orig_config.aspect_ratio_buckets, ui_data ) new_config.resolution = ui_data.pop(self.resolution) new_config.center_crop = ui_data.pop(self.center_crop) new_config.random_flip = ui_data.pop(self.random_flip) 
new_config.caption_prefix = ui_data.pop(self.caption_prefix) or None new_config.dataloader_num_workers = ui_data.pop(self.dataloader_num_workers) return new_config ================================================ FILE: src/invoke_training/ui/config_groups/optimizer_config_group.py ================================================ from typing import Any import gradio as gr from invoke_training.config.optimizer.optimizer_config import AdamOptimizerConfig, ProdigyOptimizerConfig from invoke_training.ui.config_groups.ui_config_element import UIConfigElement OptimizerConfig = AdamOptimizerConfig | ProdigyOptimizerConfig class AdamOptimizerConfigGroup(UIConfigElement): def __init__(self): with gr.Tab("Core"): with gr.Row(): self.learning_rate = gr.Number( label="Learning Rate", info="Initial learning rate to use (after the potential warmup period). Note that in some training " "pipelines this can be overriden for a specific group of params.", interactive=True, ) self.use_8bit = gr.Checkbox( label="Use 8-bit", info="Use 8-bit Adam optimizer to reduce VRAM requirements. 
(Requires bitsandbytes.)", interactive=True, ) with gr.Tab("Advanced"): with gr.Row(): self.beta1 = gr.Number(label="beta1", interactive=True) self.beta2 = gr.Number(label="beta2", interactive=True) with gr.Row(): self.weight_decay = gr.Number(label="Weight Decay", interactive=True) self.epsilon = gr.Number(label="epsilon", interactive=True) def update_ui_components_with_config_data(self, config: AdamOptimizerConfig) -> dict[gr.components.Component, Any]: return { self.learning_rate: config.learning_rate, self.beta1: config.beta1, self.beta2: config.beta2, self.weight_decay: config.weight_decay, self.epsilon: config.epsilon, self.use_8bit: config.use_8bit, } def update_config_with_ui_component_data( self, orig_config: AdamOptimizerConfig | None, ui_data: dict ) -> OptimizerConfig: assert orig_config is None return AdamOptimizerConfig( learning_rate=ui_data.pop(self.learning_rate), beta1=ui_data.pop(self.beta1), beta2=ui_data.pop(self.beta2), weight_decay=ui_data.pop(self.weight_decay), epsilon=ui_data.pop(self.epsilon), use_8bit=ui_data.pop(self.use_8bit), ) class ProdigyOptimizerConfigGroup(UIConfigElement): def __init__(self): with gr.Tab("Core"): with gr.Row(): self.learning_rate = gr.Number( label="Learning Rate", info="The learning rate. For the Prodigy optimizer, the learning rate is adjusted dynamically. A " "value of 1.0 is recommended. 
Note that in some pipelines this can be overriden for specific " "groups of parameters.", interactive=True, ) with gr.Tab("Advanced"): with gr.Row(): self.weight_decay = gr.Number(label="Weight Decay", interactive=True) with gr.Row(): self.use_bias_correction = gr.Checkbox(label="Bias Correction", interactive=True) self.safeguard_warmup = gr.Checkbox(label="Safeguard Warmup", interactive=True) def update_ui_components_with_config_data( self, config: ProdigyOptimizerConfig ) -> dict[gr.components.Component, Any]: return { self.learning_rate: config.learning_rate, self.weight_decay: config.weight_decay, self.use_bias_correction: config.use_bias_correction, self.safeguard_warmup: config.safeguard_warmup, } def update_config_with_ui_component_data( self, orig_config: ProdigyOptimizerConfig | None, ui_data: dict ) -> OptimizerConfig: assert orig_config is None return ProdigyOptimizerConfig( learning_rate=ui_data.pop(self.learning_rate), weight_decay=ui_data.pop(self.weight_decay), use_bias_correction=ui_data.pop(self.use_bias_correction), safeguard_warmup=ui_data.pop(self.safeguard_warmup), ) class OptimizerConfigGroup(UIConfigElement): def __init__(self): with gr.Group(): self.optimizer_type = gr.Dropdown(label="optimizer", choices=["AdamW", "Prodigy"], interactive=True) with gr.Group() as adam_optimizer_config_group: self.adam_optimizer_config = AdamOptimizerConfigGroup() self.adam_optimizer_config_group = adam_optimizer_config_group with gr.Group() as prodigy_optimizer_config_group: self.prodigy_optimizer_config = ProdigyOptimizerConfigGroup() self.prodigy_optimizer_config_group = prodigy_optimizer_config_group self.optimizer_type.change( self._on_optimizer_type_change, inputs=[self.optimizer_type], outputs=[self.adam_optimizer_config_group, self.prodigy_optimizer_config_group], ) def _on_optimizer_type_change(self, optimizer_type: str): return { self.adam_optimizer_config_group: gr.Group(visible=optimizer_type == "AdamW"), self.prodigy_optimizer_config_group: 
gr.Group(visible=optimizer_type == "Prodigy"), } def update_ui_components_with_config_data(self, config: OptimizerConfig) -> dict[gr.components.Component, Any]: update_dict = { self.optimizer_type: config.optimizer_type, self.adam_optimizer_config_group: gr.Group(visible=config.optimizer_type == "AdamW"), self.prodigy_optimizer_config_group: gr.Group(visible=config.optimizer_type == "Prodigy"), } update_dict.update( self.adam_optimizer_config.update_ui_components_with_config_data( config if config.optimizer_type == "AdamW" else AdamOptimizerConfig() ) ) update_dict.update( self.prodigy_optimizer_config.update_ui_components_with_config_data( config if config.optimizer_type == "Prodigy" else ProdigyOptimizerConfig() ) ) return update_dict def update_config_with_ui_component_data(self, orig_config: OptimizerConfig, ui_data: dict) -> OptimizerConfig: # TODO: Use orig_config? new_config_adam = self.adam_optimizer_config.update_config_with_ui_component_data(None, ui_data) new_config_prodigy = self.prodigy_optimizer_config.update_config_with_ui_component_data(None, ui_data) optimizer_type = ui_data.pop(self.optimizer_type) if optimizer_type == "AdamW": return new_config_adam elif optimizer_type == "Prodigy": return new_config_prodigy else: raise ValueError(f"Invalid optimizer type: {optimizer_type}") ================================================ FILE: src/invoke_training/ui/config_groups/sd_lora_config_group.py ================================================ import typing import gradio as gr from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup from invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import ( ImageCaptionSDDataLoaderConfigGroup, ) from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup from invoke_training.ui.config_groups.ui_config_element import UIConfigElement from 
invoke_training.ui.utils.prompts import ( convert_pos_neg_prompts_to_ui_prompts, convert_ui_prompts_to_pos_neg_prompts, ) from invoke_training.ui.utils.utils import get_typing_literal_options class SdLoraConfigGroup(UIConfigElement): def __init__(self): """The SD_LORA configs.""" gr.Markdown("## Basic Configs") with gr.Row(): with gr.Column(scale=1): with gr.Tab("Base Model"): self.model = gr.Textbox( label="Model", info="The base model. Can be a Hugging Face Hub model name, or a path to a local model (in " "diffusers or checkpoint format).", type="text", interactive=True, ) self.hf_variant = gr.Textbox( label="Variant", info="(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a" " HF Hub model name.", type="text", interactive=True, ) with gr.Column(scale=3): with gr.Tab("Training Outputs"): self.base_pipeline_config_group = BasePipelineConfigGroup() self.max_checkpoints = gr.Number( label="Maximum Number of Checkpoints", info="The maximum number of checkpoints to keep on disk from this training run. Earlier " "checkpoints will be deleted to respect this limit.", interactive=True, precision=0, ) gr.Markdown("## Data Configs") self.image_caption_sd_data_loader_config_group = ImageCaptionSDDataLoaderConfigGroup() gr.Markdown("## Optimizer Configs") self.optimizer_config_group = OptimizerConfigGroup() gr.Markdown("## Speed / Memory Configs") with gr.Group(): with gr.Row(): self.gradient_accumulation_steps = gr.Number( label="Gradient Accumulation Steps", info="The number of gradient steps to accumulate before each weight update. This is an" " alternative to increasing the batch size when training with limited VRAM. " "effective_batch_size = train_batch_size * gradient_accumulation_steps.", precision=0, interactive=True, ) with gr.Row(): self.weight_dtype = gr.Dropdown( label="Weight Type", info="The precision of the model weights. 
Lower precision can speed up training and reduce memory, " "with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases " "if your GPU supports it.", choices=get_typing_literal_options(SdLoraConfig, "weight_dtype"), interactive=True, ) with gr.Row(): self.cache_text_encoder_outputs = gr.Checkbox( label="Cache Text Encoder Outputs", info="Cache the text encoder outputs to increase speed. This should not be used when training the " "text encoder or performing data augmentations that would change the text encoder outputs.", interactive=True, ) self.cache_vae_outputs = gr.Checkbox( label="Cache VAE Outputs", info="Cache the VAE outputs to increase speed. This should not be used when training the UNet or " "performing data augmentations that would change the VAE outputs.", interactive=True, ) with gr.Row(): self.enable_cpu_offload_during_validation = gr.Checkbox( label="Enable CPU Offload during Validation", info="Offload models to the CPU sequentially during validation. This reduces peak VRAM " "requirements at the cost of slower validation during training.", interactive=True, ) self.gradient_checkpointing = gr.Checkbox( label="Gradient Checkpointing", info="If True, VRAM requirements are reduced at the cost of ~20% slower training", interactive=True, ) gr.Markdown("## General Training Configs") with gr.Tab("Core"): with gr.Row(): self.train_unet = gr.Checkbox(label="Train UNet", interactive=True) self.train_text_encoder = gr.Checkbox(label="Train Text Encoder", interactive=True) with gr.Row(): self.unet_learning_rate = gr.Number( label="UNet Learning Rate", info="The UNet learning rate. Set to 0 or leave empty to inherit from the base optimizer " "learning rate.", interactive=True, ) self.text_encoder_learning_rate = gr.Number( label="Text Encoder Learning Rate", info="The text encoder learning rate. 
Set to 0 or leave empty to inherit from the base optimizer " "learning rate.", interactive=True, ) with gr.Row(): self.lr_scheduler = gr.Dropdown( label="Learning Rate Scheduler", choices=get_typing_literal_options(SdLoraConfig, "lr_scheduler"), interactive=True, ) self.lr_warmup_steps = gr.Number( label="Warmup Steps", info="The number of warmup steps in the " "learning rate schedule, if applicable to the selected scheduler.", interactive=True, ) with gr.Row(): self.use_masks = gr.Checkbox( label="Use Masks", info="This can only be enabled if the dataset contains masks.", interactive=True ) with gr.Tab("Advanced"): with gr.Column(): self.lora_rank_dim = gr.Number( label="LoRA Rank Dim", info="The rank dimension to use for the LoRA layers. Increasing the rank dimension" " increases the model's expressivity, but also increases the size of the generated LoRA model.", interactive=True, precision=0, ) self.min_snr_gamma = gr.Number( label="Minumum SNR Gamma", info="min_snr_gamma acts like an an upper bound on the weight of samples with low noise " "levels. If None, then Min-SNR weighting will not be applied. If enabled, the recommended " "value is min_snr gamma = 5.0.", interactive=True, ) self.max_grad_norm = gr.Number( label="Max Gradient Norm", info="Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).", interactive=True, ) self.train_batch_size = gr.Number( label="Batch Size", info="The Training Batch Size - Higher values require increasing amounts of VRAM.", precision=0, interactive=True, ) gr.Markdown("## Validation") with gr.Group(): self.validation_prompts = gr.Textbox( label="Validation Prompts", info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' " "delimiter. For example: `positive prompt[NEG]negative prompt`. 
", lines=5, interactive=True, ) self.num_validation_images_per_prompt = gr.Number( label="# of Validation Images to Generate per Prompt", precision=0, interactive=True ) def update_ui_components_with_config_data(self, config: SdLoraConfig) -> dict[gr.components.Component, typing.Any]: update_dict = { self.model: config.model, self.hf_variant: config.hf_variant, self.max_checkpoints: config.max_checkpoints, self.train_unet: config.train_unet, self.unet_learning_rate: config.unet_learning_rate, self.train_text_encoder: config.train_text_encoder, self.text_encoder_learning_rate: config.text_encoder_learning_rate, self.lr_scheduler: config.lr_scheduler, self.lr_warmup_steps: config.lr_warmup_steps, self.use_masks: config.use_masks, self.max_grad_norm: config.max_grad_norm, self.train_batch_size: config.train_batch_size, self.cache_text_encoder_outputs: config.cache_text_encoder_outputs, self.cache_vae_outputs: config.cache_vae_outputs, self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation, self.gradient_accumulation_steps: config.gradient_accumulation_steps, self.weight_dtype: config.weight_dtype, self.gradient_checkpointing: config.gradient_checkpointing, self.lora_rank_dim: config.lora_rank_dim, self.min_snr_gamma: config.min_snr_gamma, self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts( config.validation_prompts, config.negative_validation_prompts ), self.num_validation_images_per_prompt: config.num_validation_images_per_prompt, } update_dict.update( self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader) ) update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config)) update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer)) # Sanity check to catch if we accidentally forget to update a UI component. 
assert set(update_dict.keys()) == set(self.get_ui_output_components()) return update_dict def update_config_with_ui_component_data( self, orig_config: SdLoraConfig, ui_data: dict[gr.components.Component, typing.Any] ) -> SdLoraConfig: new_config = orig_config.model_copy(deep=True) new_config.model = ui_data.pop(self.model) new_config.hf_variant = ui_data.pop(self.hf_variant) or None new_config.max_checkpoints = ui_data.pop(self.max_checkpoints) new_config.train_unet = ui_data.pop(self.train_unet) unet_lr_value = ui_data.pop(self.unet_learning_rate) new_config.unet_learning_rate = None if unet_lr_value == 0 else unet_lr_value new_config.train_text_encoder = ui_data.pop(self.train_text_encoder) text_encoder_lr_value = ui_data.pop(self.text_encoder_learning_rate) new_config.text_encoder_learning_rate = None if text_encoder_lr_value == 0 else text_encoder_lr_value new_config.lr_scheduler = ui_data.pop(self.lr_scheduler) new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps) new_config.use_masks = ui_data.pop(self.use_masks) max_grad_norm_value = ui_data.pop(self.max_grad_norm) new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value new_config.train_batch_size = ui_data.pop(self.train_batch_size) new_config.cache_text_encoder_outputs = ui_data.pop(self.cache_text_encoder_outputs) new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs) new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation) new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps) new_config.weight_dtype = ui_data.pop(self.weight_dtype) new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing) new_config.lora_rank_dim = ui_data.pop(self.lora_rank_dim) new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma) new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt) positive_prompts, negative_prompts = 
convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts)) new_config.validation_prompts = positive_prompts new_config.negative_validation_prompts = negative_prompts new_config.data_loader = self.image_caption_sd_data_loader_config_group.update_config_with_ui_component_data( new_config.data_loader, ui_data ) new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data) new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data( new_config.optimizer, ui_data ) # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred # to the config. assert len(ui_data) == 0 return new_config ================================================ FILE: src/invoke_training/ui/config_groups/sd_textual_inversion_config_group.py ================================================ import typing import gradio as gr from invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup from invoke_training.ui.config_groups.textual_inversion_sd_data_loader_config_group import ( TextualInversionSDDataLoaderConfigGroup, ) from invoke_training.ui.config_groups.ui_config_element import UIConfigElement from invoke_training.ui.utils.prompts import ( convert_pos_neg_prompts_to_ui_prompts, convert_ui_prompts_to_pos_neg_prompts, ) from invoke_training.ui.utils.utils import get_typing_literal_options class SdTextualInversionConfigGroup(UIConfigElement): def __init__(self): """The SD_TEXTUAL_INVERSION configs.""" gr.Markdown("## Basic Configs") with gr.Row(): with gr.Column(scale=1): with gr.Tab("Base Model"): self.model = gr.Textbox( label="Model", info="The base model. 
Can be a Hugging Face Hub model name, or a path to a local model (in " "diffusers or checkpoint format).", type="text", interactive=True, ) self.hf_variant = gr.Textbox( label="Variant", info="(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a" " HF Hub model name.", type="text", interactive=True, ) with gr.Column(scale=3): with gr.Tab("Training Outputs"): self.base_pipeline_config_group = BasePipelineConfigGroup() self.max_checkpoints = gr.Number( label="Maximum Number of Checkpoints", info="The maximum number of checkpoints to keep on disk from this training run. Earlier " "checkpoints will be deleted to respect this limit.", interactive=True, precision=0, ) gr.Markdown("## Data Configs") self.textual_inversion_sd_data_loader_config_group = TextualInversionSDDataLoaderConfigGroup() gr.Markdown("## Textual Inversion Configs") self.num_vectors = gr.Number( label="Num Vectors", info="The number of TI vectors that will be trained. Can be overriden by 'Initial Phrase'.", interactive=True, precision=0, ) self.placeholder_token = gr.Textbox( label="Placeholder Token", info="The special word to associate the learned embeddings with. Choose a unique token that is unlikely to " "already exist in the tokenizer's vocabulary.", interactive=True, ) self.initializer_token = gr.Textbox( label="Initializer Token", info="Only one of 'Initializer Token' or 'Initial Phrase' should be set. A vocabulary token to use as an " "initializer for the placeholder token. It should be a single word that roughly describes the object or " "style that you're trying to train on. The initializer token ust map to a single tokenizer token.", interactive=True, ) self.initial_phrase = gr.Textbox( label="Initial Phrase", info="Only one of 'Initializer Token' or 'Initial Phrase' should be set. A phrase that will be used to " "initialize the placeholder token embedding. 
The phrase will be tokenized, and the corresponding " "embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be " "inferred from the length of the tokenized phrase, so keep the phrase short.", interactive=True, ) gr.Markdown("## Optimizer Configs") self.optimizer_config_group = OptimizerConfigGroup() gr.Markdown("## Speed / Memory Configs") with gr.Group(): with gr.Row(): self.gradient_accumulation_steps = gr.Number( label="Gradient Accumulation Steps", info="The number of gradient steps to accumulate before each weight update. This is an" " alternative to increasing the batch size when training with limited VRAM." "effective_batch_size = train_batch_size * gradient_accumulation_steps.", precision=0, interactive=True, ) with gr.Row(): self.weight_dtype = gr.Dropdown( label="Weight Type", info="The precision of the model weights. Lower precision can speed up training and reduce memory, " "with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases " "if your GPU supports it.", choices=get_typing_literal_options(SdTextualInversionConfig, "weight_dtype"), interactive=True, ) with gr.Row(): self.cache_vae_outputs = gr.Checkbox( label="Cache VAE Outputs", info="Cache the VAE outputs to increase speed. This should not be used when training the UNet or " "performing data augmentations that would change the VAE outputs.", interactive=True, ) with gr.Row(): self.enable_cpu_offload_during_validation = gr.Checkbox( label="Enable CPU Offload during Validation", info="Offload models to the CPU sequentially during validation. 
This reduces peak VRAM " "requirements at the cost of slower validation during training.", interactive=True, ) self.gradient_checkpointing = gr.Checkbox( label="Gradient Checkpointing", info="If True, VRAM requirements are reduced at the cost of ~20% slower training", interactive=True, ) gr.Markdown("## General Training Configs") with gr.Tab("Core"): with gr.Row(): self.lr_scheduler = gr.Dropdown( label="Learning Rate Scheduler", choices=get_typing_literal_options(SdTextualInversionConfig, "lr_scheduler"), interactive=True, ) self.lr_warmup_steps = gr.Number( label="Warmup Steps", info="The number of warmup steps in the " "learning rate schedule, if applicable to the selected scheduler.", interactive=True, ) with gr.Row(): self.use_masks = gr.Checkbox( label="Use Masks", info="This can only be enabled if the dataset contains masks.", interactive=True ) with gr.Tab("Advanced"): with gr.Column(): self.min_snr_gamma = gr.Number( label="Minumum SNR Gamma", info="min_snr_gamma acts like an an upper bound on the weight of samples with low noise " "levels. If None, then Min-SNR weighting will not be applied. If enabled, the recommended " "value is min_snr gamma = 5.0.", interactive=True, ) self.max_grad_norm = gr.Number( label="Max Gradient Norm", info="Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).", interactive=True, ) self.train_batch_size = gr.Number( label="Batch Size", info="The Training Batch Size - Higher values require increasing amounts of VRAM.", precision=0, interactive=True, ) gr.Markdown("## Validation") with gr.Group(): self.validation_prompts = gr.Textbox( label="Validation Prompts", info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' " "delimiter. For example: `positive prompt[NEG]negative prompt`. 
", lines=5, interactive=True, ) self.num_validation_images_per_prompt = gr.Number( label="# of Validation Images to Generate per Prompt", precision=0, interactive=True ) def update_ui_components_with_config_data( self, config: SdTextualInversionConfig ) -> dict[gr.components.Component, typing.Any]: update_dict = { self.model: config.model, self.hf_variant: config.hf_variant, self.num_vectors: config.num_vectors, self.placeholder_token: config.placeholder_token, self.initializer_token: config.initializer_token, self.initial_phrase: config.initial_phrase, self.max_checkpoints: config.max_checkpoints, self.lr_scheduler: config.lr_scheduler, self.lr_warmup_steps: config.lr_warmup_steps, self.use_masks: config.use_masks, self.max_grad_norm: config.max_grad_norm, self.train_batch_size: config.train_batch_size, self.cache_vae_outputs: config.cache_vae_outputs, self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation, self.gradient_accumulation_steps: config.gradient_accumulation_steps, self.weight_dtype: config.weight_dtype, self.gradient_checkpointing: config.gradient_checkpointing, self.min_snr_gamma: config.min_snr_gamma, self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts( config.validation_prompts, config.negative_validation_prompts ), self.num_validation_images_per_prompt: config.num_validation_images_per_prompt, } update_dict.update( self.textual_inversion_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader) ) update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config)) update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer)) # Sanity check to catch if we accidentally forget to update a UI component. 
assert set(update_dict.keys()) == set(self.get_ui_output_components()) return update_dict def update_config_with_ui_component_data( self, orig_config: SdTextualInversionConfig, ui_data: dict[gr.components.Component, typing.Any] ) -> SdTextualInversionConfig: new_config = orig_config.model_copy(deep=True) new_config.model = ui_data.pop(self.model) new_config.hf_variant = ui_data.pop(self.hf_variant) or None new_config.num_vectors = ui_data.pop(self.num_vectors) new_config.placeholder_token = ui_data.pop(self.placeholder_token) new_config.initializer_token = ui_data.pop(self.initializer_token) or None new_config.initial_phrase = ui_data.pop(self.initial_phrase) or None new_config.max_checkpoints = ui_data.pop(self.max_checkpoints) new_config.lr_scheduler = ui_data.pop(self.lr_scheduler) new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps) new_config.use_masks = ui_data.pop(self.use_masks) max_grad_norm_value = ui_data.pop(self.max_grad_norm) new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value new_config.train_batch_size = ui_data.pop(self.train_batch_size) new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs) new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation) new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps) new_config.weight_dtype = ui_data.pop(self.weight_dtype) new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing) new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma) new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt) positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts)) new_config.validation_prompts = positive_prompts new_config.negative_validation_prompts = negative_prompts new_config.data_loader = ( self.textual_inversion_sd_data_loader_config_group.update_config_with_ui_component_data( 
new_config.data_loader, ui_data ) ) new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data) new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data( new_config.optimizer, ui_data ) # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred # to the config. assert len(ui_data) == 0 return new_config ================================================ FILE: src/invoke_training/ui/config_groups/sdxl_finetune_config_group.py ================================================ import typing import gradio as gr from invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup from invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import ( ImageCaptionSDDataLoaderConfigGroup, ) from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup from invoke_training.ui.config_groups.ui_config_element import UIConfigElement from invoke_training.ui.utils.prompts import ( convert_pos_neg_prompts_to_ui_prompts, convert_ui_prompts_to_pos_neg_prompts, ) from invoke_training.ui.utils.utils import get_typing_literal_options class SdxlFinetuneConfigGroup(UIConfigElement): def __init__(self): """The SDXL_FINETUNE configs.""" gr.Markdown("## Basic Configs") with gr.Row(): with gr.Column(scale=1): with gr.Tab("Base Model"): self.model = gr.Textbox( label="Model", info="The base model. 
Can be a Hugging Face Hub model name, or a path to a local model (in " "diffusers or checkpoint format).", type="text", interactive=True, ) self.hf_variant = gr.Textbox( label="Variant", info="(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a" " HF Hub model name.", type="text", interactive=True, ) self.vae_model = gr.Textbox( label="VAE Model", info="(optional) If set, this overrides the base model's default VAE model.", type="text", interactive=True, ) with gr.Column(scale=3): with gr.Tab("Training Outputs"): self.base_pipeline_config_group = BasePipelineConfigGroup() self.save_checkpoint_format = gr.Dropdown( label="Checkpoint Format", info="The save format for the checkpoints. `full_diffusers` saves the full model in diffusers " "format. `trained_only_diffusers` saves only the parts of the model that were finetuned " "(i.e. the UNet).", choices=get_typing_literal_options(SdxlFinetuneConfig, "save_checkpoint_format"), interactive=True, ) self.save_dtype = gr.Dropdown( label="Save Dtype", info="The dtype to use when saving the model.", choices=get_typing_literal_options(SdxlFinetuneConfig, "save_dtype"), interactive=True, ) self.max_checkpoints = gr.Number( label="Maximum Number of Checkpoints", info="The maximum number of checkpoints to keep on disk from this training run. Earlier " "checkpoints will be deleted to respect this limit.", interactive=True, precision=0, ) gr.Markdown("## Data Configs") self.image_caption_sd_data_loader_config_group = ImageCaptionSDDataLoaderConfigGroup() gr.Markdown("## Optimizer Configs") self.optimizer_config_group = OptimizerConfigGroup() gr.Markdown("## Speed / Memory Configs") with gr.Group(): with gr.Row(): self.gradient_accumulation_steps = gr.Number( label="Gradient Accumulation Steps", info="The number of gradient steps to accumulate before each weight update. This is an alternative" "to increasing the batch size when training with limited VRAM." 
"effective_batch_size = train_batch_size * gradient_accumulation_steps.", precision=0, interactive=True, ) with gr.Row(): self.weight_dtype = gr.Dropdown( label="Weight Type", info="The precision of the model weights. Lower precision can speed up training and reduce memory, " "with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases " "if your GPU supports it.", choices=get_typing_literal_options(SdxlFinetuneConfig, "weight_dtype"), interactive=True, ) with gr.Row(): self.cache_text_encoder_outputs = gr.Checkbox( label="Cache Text Encoder Outputs", info="Cache the text encoder outputs to increase speed. This should not be used when training the " "text encoder or performing data augmentations that would change the text encoder outputs.", interactive=True, ) self.cache_vae_outputs = gr.Checkbox( label="Cache VAE Outputs", info="Cache the VAE outputs to increase speed. This should not be used when training the UNet or " "performing data augmentations that would change the VAE outputs.", interactive=True, ) with gr.Row(): self.enable_cpu_offload_during_validation = gr.Checkbox( label="Enable CPU Offload during Validation", info="Offload models to the CPU sequentially during validation. 
This reduces peak VRAM " "requirements at the cost of slower validation during training.", interactive=True, ) self.gradient_checkpointing = gr.Checkbox( label="Gradient Checkpointing", info="If True, VRAM requirements are reduced at the cost of ~20% slower training", interactive=True, ) gr.Markdown("## General Training Configs") with gr.Tab("Core"): with gr.Row(): self.lr_scheduler = gr.Dropdown( label="Learning Rate Scheduler", choices=get_typing_literal_options(SdxlFinetuneConfig, "lr_scheduler"), interactive=True, ) self.lr_warmup_steps = gr.Number( label="Warmup Steps", info="The number of warmup steps in the " "learning rate schedule, if applicable to the selected scheduler.", interactive=True, ) with gr.Row(): self.use_masks = gr.Checkbox( label="Use Masks", info="This can only be enabled if the dataset contains masks.", interactive=True ) with gr.Tab("Advanced"): with gr.Row(): self.min_snr_gamma = gr.Number( label="Minimum SNR Gamma", info="min_snr_gamma acts like an an upper bound on the weight of samples with low noise " "levels. If None, then Min-SNR weighting will not be applied. If enabled, the recommended " "value is min_snr gamma = 5.0.", interactive=True, ) self.max_grad_norm = gr.Number( label="Max Gradient Norm", info="Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).", interactive=True, ) self.train_batch_size = gr.Number( label="Batch Size", info="The Training Batch Size - Higher values require increasing amounts of VRAM.", precision=0, interactive=True, ) gr.Markdown("## Validation") with gr.Group(): self.validation_prompts = gr.Textbox( label="Validation Prompts", info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' " "delimiter. For example: `positive prompt[NEG]negative prompt`. 
", lines=5, interactive=True, ) self.num_validation_images_per_prompt = gr.Number( label="# of Validation Images to Generate per Prompt", precision=0, interactive=True ) def update_ui_components_with_config_data( self, config: SdxlFinetuneConfig ) -> dict[gr.components.Component, typing.Any]: update_dict = { self.model: config.model, self.hf_variant: config.hf_variant, self.vae_model: config.vae_model, self.save_checkpoint_format: config.save_checkpoint_format, self.save_dtype: config.save_dtype, self.max_checkpoints: config.max_checkpoints, self.lr_scheduler: config.lr_scheduler, self.lr_warmup_steps: config.lr_warmup_steps, self.use_masks: config.use_masks, self.min_snr_gamma: config.min_snr_gamma, self.max_grad_norm: config.max_grad_norm, self.train_batch_size: config.train_batch_size, self.cache_text_encoder_outputs: config.cache_text_encoder_outputs, self.cache_vae_outputs: config.cache_vae_outputs, self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation, self.gradient_accumulation_steps: config.gradient_accumulation_steps, self.weight_dtype: config.weight_dtype, self.gradient_checkpointing: config.gradient_checkpointing, self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts( config.validation_prompts, config.negative_validation_prompts ), self.num_validation_images_per_prompt: config.num_validation_images_per_prompt, } update_dict.update( self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader) ) update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config)) update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer)) # Sanity check to catch if we accidentally forget to update a UI component. 
assert set(update_dict.keys()) == set(self.get_ui_output_components()) return update_dict def update_config_with_ui_component_data( self, orig_config: SdxlFinetuneConfig, ui_data: dict[gr.components.Component, typing.Any] ) -> SdxlFinetuneConfig: new_config = orig_config.model_copy(deep=True) new_config.model = ui_data.pop(self.model) new_config.hf_variant = ui_data.pop(self.hf_variant) or None new_config.vae_model = ui_data.pop(self.vae_model) or None new_config.save_checkpoint_format = ui_data.pop(self.save_checkpoint_format) new_config.save_dtype = ui_data.pop(self.save_dtype) new_config.max_checkpoints = ui_data.pop(self.max_checkpoints) new_config.lr_scheduler = ui_data.pop(self.lr_scheduler) new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps) new_config.use_masks = ui_data.pop(self.use_masks) new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma) max_grad_norm_value = ui_data.pop(self.max_grad_norm) new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value new_config.train_batch_size = ui_data.pop(self.train_batch_size) new_config.cache_text_encoder_outputs = ui_data.pop(self.cache_text_encoder_outputs) new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs) new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation) new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps) new_config.weight_dtype = ui_data.pop(self.weight_dtype) new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing) new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt) positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts)) new_config.validation_prompts = positive_prompts new_config.negative_validation_prompts = negative_prompts new_config.data_loader = self.image_caption_sd_data_loader_config_group.update_config_with_ui_component_data( 
new_config.data_loader, ui_data ) new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data) new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data( new_config.optimizer, ui_data ) # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred # to the config. assert len(ui_data) == 0 return new_config ================================================ FILE: src/invoke_training/ui/config_groups/sdxl_lora_and_textual_inversion_config_group.py ================================================ import typing import gradio as gr from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import ( SdxlLoraAndTextualInversionConfig, ) from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup from invoke_training.ui.config_groups.textual_inversion_sd_data_loader_config_group import ( TextualInversionSDDataLoaderConfigGroup, ) from invoke_training.ui.config_groups.ui_config_element import UIConfigElement from invoke_training.ui.utils.prompts import ( convert_pos_neg_prompts_to_ui_prompts, convert_ui_prompts_to_pos_neg_prompts, ) from invoke_training.ui.utils.utils import get_typing_literal_options class SdxlLoraAndTextualInversionConfigGroup(UIConfigElement): def __init__(self): """The SDXL_LORA_AND_TEXTUAL_INVERSION configs.""" gr.Markdown("## Basic Configs") with gr.Row(): with gr.Column(scale=1): with gr.Tab("Base Model"): self.model = gr.Textbox( label="Model", info="The base model. 
Can be a Hugging Face Hub model name, or a path to a local model (in " "diffusers or checkpoint format).", type="text", interactive=True, ) self.hf_variant = gr.Textbox( label="Variant", info="(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a" " HF Hub model name.", type="text", interactive=True, ) self.vae_model = gr.Textbox( label="VAE Model", info="(optional) If set, this overrides the base model's default VAE model.", type="text", interactive=True, ) with gr.Column(scale=3): with gr.Tab("Training Outputs"): self.base_pipeline_config_group = BasePipelineConfigGroup() self.max_checkpoints = gr.Number( label="Maximum Number of Checkpoints", info="The maximum number of checkpoints to keep on disk from this training run. Earlier " "checkpoints will be deleted to respect this limit.", interactive=True, precision=0, ) gr.Markdown("## Data Configs") self.image_caption_sd_data_loader_config_group = TextualInversionSDDataLoaderConfigGroup() gr.Markdown("## Textual Inversion Configs") self.num_vectors = gr.Number( label="Num Vectors", info="The number of TI vectors that will be trained. Can be overriden by 'Initial Phrase'.", interactive=True, precision=0, ) self.placeholder_token = gr.Textbox( label="Placeholder Token", info="The special word to associate the learned embeddings with. Choose a unique token that is unlikely to " "already exist in the tokenizer's vocabulary.", interactive=True, ) self.initializer_token = gr.Textbox( label="Initializer Token", info="Only one of 'Initializer Token' or 'Initial Phrase' should be set. A vocabulary token to use as an " "initializer for the placeholder token. It should be a single word that roughly describes the object or " "style that you're trying to train on. The initializer token ust map to a single tokenizer token.", interactive=True, ) self.initial_phrase = gr.Textbox( label="Initial Phrase", info="Only one of 'Initializer Token' or 'Initial Phrase' should be set. 
A phrase that will be used to " "initialize the placeholder token embedding. The phrase will be tokenized, and the corresponding " "embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be " "inferred from the length of the tokenized phrase, so keep the phrase short.", interactive=True, ) gr.Markdown("## Optimizer Configs") self.optimizer_config_group = OptimizerConfigGroup() gr.Markdown("## Speed / Memory Configs") with gr.Group(): with gr.Row(): self.gradient_accumulation_steps = gr.Number( label="Gradient Accumulation Steps", info="The number of gradient steps to accumulate before each weight update. This is an alternative" "to increasing the batch size when training with limited VRAM." "effective_batch_size = train_batch_size * gradient_accumulation_steps.", precision=0, interactive=True, ) with gr.Row(): self.weight_dtype = gr.Dropdown( label="Weight Type", info="The precision of the model weights. Lower precision can speed up training and reduce memory, " "with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases " "if your GPU supports it.", choices=get_typing_literal_options(SdxlLoraAndTextualInversionConfig, "weight_dtype"), interactive=True, ) with gr.Row(): self.cache_text_encoder_outputs = gr.Checkbox( label="Cache Text Encoder Outputs", info="Cache the text encoder outputs to increase speed. This should not be used when training the " "text encoder or performing data augmentations that would change the text encoder outputs.", interactive=True, ) self.cache_vae_outputs = gr.Checkbox( label="Cache VAE Outputs", info="Cache the VAE outputs to increase speed. This should not be used when training the UNet or " "performing data augmentations that would change the VAE outputs.", interactive=True, ) with gr.Row(): self.enable_cpu_offload_during_validation = gr.Checkbox( label="Enable CPU Offload during Validation", info="Offload models to the CPU sequentially during validation. 
This reduces peak VRAM " "requirements at the cost of slower validation during training.", interactive=True, ) self.gradient_checkpointing = gr.Checkbox( label="Gradient Checkpointing", info="If True, VRAM requirements are reduced at the cost of ~20% slower training", interactive=True, ) gr.Markdown("## General Training Configs") with gr.Tab("Core"): with gr.Row(): self.train_unet = gr.Checkbox(label="Train UNet", interactive=True) self.train_text_encoder = gr.Checkbox(label="Train Text Encoder", interactive=True) self.train_ti = gr.Checkbox(label="Train Textual Inversion Token", scale=2, interactive=True) with gr.Row(): self.unet_learning_rate = gr.Number( label="UNet Learning Rate", info="The UNet learning rate. Set to 0 or leave empty to inherit from the base optimizer " "learning rate.", interactive=True, ) self.text_encoder_learning_rate = gr.Number( label="Text Encoder Learning Rate", info="The text encoder learning rate. Set to 0 or leave empty to inherit from the base optimizer " "learning rate.", interactive=True, ) self.textual_inversion_learning_rate = gr.Number( label="Textual Inversion Learning Rate", info="The textual inversion learning rate. 
Set to 0 or leave empty to inherit from the base " "optimizer learning rate.", interactive=True, ) self.ti_train_steps_ratio = gr.Number(label="Textual Inversion Train Steps Ratio", interactive=True) with gr.Row(): self.lr_scheduler = gr.Dropdown( label="Learning Rate Scheduler", choices=get_typing_literal_options(SdxlLoraAndTextualInversionConfig, "lr_scheduler"), interactive=True, ) self.lr_warmup_steps = gr.Number( label="Warmup Steps", info="The number of warmup steps in the " "learning rate schedule, if applicable to the selected scheduler.", interactive=True, ) with gr.Row(): self.use_masks = gr.Checkbox( label="Use Masks", info="This can only be enabled if the dataset contains masks.", interactive=True ) with gr.Tab("Advanced"): with gr.Column(): self.lora_rank_dim = gr.Number( label="LoRA Rank Dim", info="The rank dimension to use for the LoRA layers. Increasing the rank dimension increases" " the model's expressivity, but also increases the size of the generated LoRA model.", interactive=True, precision=0, ) self.min_snr_gamma = gr.Number( label="Minumum SNR Gamma", info="min_snr_gamma acts like an an upper bound on the weight of samples with low noise " "levels. If None, then Min-SNR weighting will not be applied. If enabled, the recommended " "value is min_snr gamma = 5.0.", interactive=True, ) self.max_grad_norm = gr.Number( label="Max Gradient Norm", info="Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).", interactive=True, ) self.train_batch_size = gr.Number( label="Batch Size", info="The Training Batch Size - Higher values require increasing amounts of VRAM.", precision=0, interactive=True, ) gr.Markdown("## Validation") with gr.Group(): self.validation_prompts = gr.Textbox( label="Validation Prompts", info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' " "delimiter. For example: `positive prompt[NEG]negative prompt`. 
", lines=5, interactive=True, ) self.num_validation_images_per_prompt = gr.Number( label="# of Validation Images to Generate per Prompt", precision=0, interactive=True ) def update_ui_components_with_config_data( self, config: SdxlLoraAndTextualInversionConfig ) -> dict[gr.components.Component, typing.Any]: update_dict = { self.model: config.model, self.hf_variant: config.hf_variant, self.vae_model: config.vae_model, self.num_vectors: config.num_vectors, self.placeholder_token: config.placeholder_token, self.initializer_token: config.initializer_token, self.initial_phrase: config.initial_phrase, self.max_checkpoints: config.max_checkpoints, self.train_unet: config.train_unet, self.train_text_encoder: config.train_text_encoder, self.train_ti: config.train_ti, self.unet_learning_rate: config.unet_learning_rate, self.text_encoder_learning_rate: config.text_encoder_learning_rate, self.textual_inversion_learning_rate: config.textual_inversion_learning_rate, self.ti_train_steps_ratio: config.ti_train_steps_ratio, self.lr_scheduler: config.lr_scheduler, self.lr_warmup_steps: config.lr_warmup_steps, self.use_masks: config.use_masks, self.max_grad_norm: config.max_grad_norm, self.train_batch_size: config.train_batch_size, self.cache_text_encoder_outputs: config.cache_text_encoder_outputs, self.cache_vae_outputs: config.cache_vae_outputs, self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation, self.gradient_accumulation_steps: config.gradient_accumulation_steps, self.weight_dtype: config.weight_dtype, self.gradient_checkpointing: config.gradient_checkpointing, self.lora_rank_dim: config.lora_rank_dim, self.min_snr_gamma: config.min_snr_gamma, self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts( config.validation_prompts, config.negative_validation_prompts ), self.num_validation_images_per_prompt: config.num_validation_images_per_prompt, } update_dict.update( 
self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader) ) update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config)) update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer)) # Sanity check to catch if we accidentally forget to update a UI component. assert set(update_dict.keys()) == set(self.get_ui_output_components()) return update_dict def update_config_with_ui_component_data( self, orig_config: SdxlLoraAndTextualInversionConfig, ui_data: dict[gr.components.Component, typing.Any] ) -> SdxlLoraAndTextualInversionConfig: new_config = orig_config.model_copy(deep=True) new_config.model = ui_data.pop(self.model) new_config.hf_variant = ui_data.pop(self.hf_variant) or None new_config.vae_model = ui_data.pop(self.vae_model) or None new_config.num_vectors = ui_data.pop(self.num_vectors) new_config.placeholder_token = ui_data.pop(self.placeholder_token) new_config.initializer_token = ui_data.pop(self.initializer_token) or None new_config.initial_phrase = ui_data.pop(self.initial_phrase) or None new_config.max_checkpoints = ui_data.pop(self.max_checkpoints) new_config.train_unet = ui_data.pop(self.train_unet) new_config.train_text_encoder = ui_data.pop(self.train_text_encoder) new_config.train_ti = ui_data.pop(self.train_ti) unet_lr_value = ui_data.pop(self.unet_learning_rate) new_config.unet_learning_rate = None if unet_lr_value == 0 else unet_lr_value text_encoder_lr_value = ui_data.pop(self.text_encoder_learning_rate) new_config.text_encoder_learning_rate = None if text_encoder_lr_value == 0 else text_encoder_lr_value ti_lr_value = ui_data.pop(self.textual_inversion_learning_rate) new_config.textual_inversion_learning_rate = None if ti_lr_value == 0 else ti_lr_value new_config.ti_train_steps_ratio = ui_data.pop(self.ti_train_steps_ratio) new_config.lr_scheduler = ui_data.pop(self.lr_scheduler) new_config.lr_warmup_steps = 
ui_data.pop(self.lr_warmup_steps) new_config.use_masks = ui_data.pop(self.use_masks) max_grad_norm_value = ui_data.pop(self.max_grad_norm) new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value new_config.train_batch_size = ui_data.pop(self.train_batch_size) new_config.cache_text_encoder_outputs = ui_data.pop(self.cache_text_encoder_outputs) new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs) new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation) new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps) new_config.weight_dtype = ui_data.pop(self.weight_dtype) new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing) new_config.lora_rank_dim = ui_data.pop(self.lora_rank_dim) new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma) new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt) positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts)) new_config.validation_prompts = positive_prompts new_config.negative_validation_prompts = negative_prompts new_config.data_loader = self.image_caption_sd_data_loader_config_group.update_config_with_ui_component_data( new_config.data_loader, ui_data ) new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data) new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data( new_config.optimizer, ui_data ) # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred # to the config. 
assert len(ui_data) == 0 return new_config ================================================ FILE: src/invoke_training/ui/config_groups/sdxl_lora_config_group.py ================================================ import typing import gradio as gr from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup from invoke_training.ui.config_groups.image_caption_sd_data_loader_config_group import ( ImageCaptionSDDataLoaderConfigGroup, ) from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup from invoke_training.ui.config_groups.ui_config_element import UIConfigElement from invoke_training.ui.utils.prompts import ( convert_pos_neg_prompts_to_ui_prompts, convert_ui_prompts_to_pos_neg_prompts, ) from invoke_training.ui.utils.utils import get_typing_literal_options class SdxlLoraConfigGroup(UIConfigElement): def __init__(self): """The SD_LORA configs.""" gr.Markdown("## Basic Configs") with gr.Row(): with gr.Column(scale=1): with gr.Tab("Base Model"): self.model = gr.Textbox( label="Model", info="The base model. Can be a Hugging Face Hub model name, or a path to a local model (in " "diffusers or checkpoint format).", type="text", interactive=True, ) self.hf_variant = gr.Textbox( label="Variant", info="(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a" " HF Hub model name.", type="text", interactive=True, ) self.vae_model = gr.Textbox( label="VAE Model", info="(optional) If set, this overrides the base model's default VAE model.", type="text", interactive=True, ) with gr.Column(scale=3): with gr.Tab("Training Outputs"): self.base_pipeline_config_group = BasePipelineConfigGroup() self.max_checkpoints = gr.Number( label="Maximum Number of Checkpoints", info="The maximum number of checkpoints to keep on disk from this training run. 
Earlier " "checkpoints will be deleted to respect this limit.", interactive=True, precision=0, ) gr.Markdown("## Data Configs") self.image_caption_sd_data_loader_config_group = ImageCaptionSDDataLoaderConfigGroup() gr.Markdown("## Optimizer Configs") self.optimizer_config_group = OptimizerConfigGroup() gr.Markdown("## Speed / Memory Configs") with gr.Group(): with gr.Row(): self.gradient_accumulation_steps = gr.Number( label="Gradient Accumulation Steps", info="The number of gradient steps to accumulate before each weight update. This is an alternative" "to increasing the batch size when training with limited VRAM." "effective_batch_size = train_batch_size * gradient_accumulation_steps.", precision=0, interactive=True, ) with gr.Row(): self.weight_dtype = gr.Dropdown( label="Weight Type", info="The precision of the model weights. Lower precision can speed up training and reduce memory, " "with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases " "if your GPU supports it.", choices=get_typing_literal_options(SdxlLoraConfig, "weight_dtype"), interactive=True, ) with gr.Row(): self.cache_text_encoder_outputs = gr.Checkbox( label="Cache Text Encoder Outputs", info="Cache the text encoder outputs to increase speed. This should not be used when training the " "text encoder or performing data augmentations that would change the text encoder outputs.", interactive=True, ) self.cache_vae_outputs = gr.Checkbox( label="Cache VAE Outputs", info="Cache the VAE outputs to increase speed. This should not be used when training the UNet or " "performing data augmentations that would change the VAE outputs.", interactive=True, ) with gr.Row(): self.enable_cpu_offload_during_validation = gr.Checkbox( label="Enable CPU Offload during Validation", info="Offload models to the CPU sequentially during validation. 
This reduces peak VRAM " "requirements at the cost of slower validation during training.", interactive=True, ) self.gradient_checkpointing = gr.Checkbox( label="Gradient Checkpointing", info="If True, VRAM requirements are reduced at the cost of ~20% slower training", interactive=True, ) gr.Markdown("## General Training Configs") with gr.Tab("Core"): with gr.Row(): self.train_unet = gr.Checkbox(label="Train UNet", interactive=True) self.train_text_encoder = gr.Checkbox(label="Train Text Encoder", interactive=True) with gr.Row(): self.unet_learning_rate = gr.Number( label="UNet Learning Rate", info="The UNet learning rate. Set to 0 or leave empty to inherit from the base optimizer " "learning rate.", interactive=True, ) self.text_encoder_learning_rate = gr.Number( label="Text Encoder Learning Rate", info="The text encoder learning rate. Set to 0 or leave empty to inherit from the base optimizer " "learning rate.", interactive=True, ) with gr.Row(): self.lr_scheduler = gr.Dropdown( label="Learning Rate Scheduler", choices=get_typing_literal_options(SdxlLoraConfig, "lr_scheduler"), interactive=True, ) self.lr_warmup_steps = gr.Number( label="Warmup Steps", info="The number of warmup steps in the " "learning rate schedule, if applicable to the selected scheduler.", interactive=True, ) with gr.Row(): self.use_masks = gr.Checkbox( label="Use Masks", info="This can only be enabled if the dataset contains masks.", interactive=True ) with gr.Tab("Advanced"): with gr.Column(): self.lora_rank_dim = gr.Number( label="LoRA Rank Dim", info="The rank dimension to use for the LoRA layers. Increasing the rank dimension increases" " the model's expressivity, but also increases the size of the generated LoRA model.", interactive=True, precision=0, ) self.min_snr_gamma = gr.Number( label="Minumum SNR Gamma", info="min_snr_gamma acts like an an upper bound on the weight of samples with low noise " "levels. If None, then Min-SNR weighting will not be applied. 
If enabled, the recommended " "value is min_snr gamma = 5.0.", interactive=True, ) self.max_grad_norm = gr.Number( label="Max Gradient Norm", info="Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).", interactive=True, ) self.train_batch_size = gr.Number( label="Batch Size", info="The Training Batch Size - Higher values require increasing amounts of VRAM.", precision=0, interactive=True, ) gr.Markdown("## Validation") with gr.Group(): self.validation_prompts = gr.Textbox( label="Validation Prompts", info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' " "delimiter. For example: `positive prompt[NEG]negative prompt`. ", lines=5, interactive=True, ) self.num_validation_images_per_prompt = gr.Number( label="# of Validation Images to Generate per Prompt", precision=0, interactive=True ) def update_ui_components_with_config_data( self, config: SdxlLoraConfig ) -> dict[gr.components.Component, typing.Any]: update_dict = { self.model: config.model, self.hf_variant: config.hf_variant, self.vae_model: config.vae_model, self.max_checkpoints: config.max_checkpoints, self.train_unet: config.train_unet, self.unet_learning_rate: config.unet_learning_rate, self.train_text_encoder: config.train_text_encoder, self.text_encoder_learning_rate: config.text_encoder_learning_rate, self.lr_scheduler: config.lr_scheduler, self.lr_warmup_steps: config.lr_warmup_steps, self.use_masks: config.use_masks, self.max_grad_norm: config.max_grad_norm, self.train_batch_size: config.train_batch_size, self.cache_text_encoder_outputs: config.cache_text_encoder_outputs, self.cache_vae_outputs: config.cache_vae_outputs, self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation, self.gradient_accumulation_steps: config.gradient_accumulation_steps, self.weight_dtype: config.weight_dtype, self.gradient_checkpointing: config.gradient_checkpointing, self.lora_rank_dim: config.lora_rank_dim, self.min_snr_gamma: 
config.min_snr_gamma, self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts( config.validation_prompts, config.negative_validation_prompts ), self.num_validation_images_per_prompt: config.num_validation_images_per_prompt, } update_dict.update( self.image_caption_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader) ) update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config)) update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer)) # Sanity check to catch if we accidentally forget to update a UI component. assert set(update_dict.keys()) == set(self.get_ui_output_components()) return update_dict def update_config_with_ui_component_data( self, orig_config: SdxlLoraConfig, ui_data: dict[gr.components.Component, typing.Any] ) -> SdxlLoraConfig: new_config = orig_config.model_copy(deep=True) new_config.model = ui_data.pop(self.model) new_config.hf_variant = ui_data.pop(self.hf_variant) or None new_config.vae_model = ui_data.pop(self.vae_model) or None new_config.max_checkpoints = ui_data.pop(self.max_checkpoints) new_config.train_unet = ui_data.pop(self.train_unet) unet_lr_value = ui_data.pop(self.unet_learning_rate) new_config.unet_learning_rate = None if unet_lr_value == 0 else unet_lr_value new_config.train_text_encoder = ui_data.pop(self.train_text_encoder) text_encoder_lr_value = ui_data.pop(self.text_encoder_learning_rate) new_config.text_encoder_learning_rate = None if text_encoder_lr_value == 0 else text_encoder_lr_value new_config.lr_scheduler = ui_data.pop(self.lr_scheduler) new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps) new_config.use_masks = ui_data.pop(self.use_masks) max_grad_norm_value = ui_data.pop(self.max_grad_norm) new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value new_config.train_batch_size = ui_data.pop(self.train_batch_size) new_config.cache_text_encoder_outputs = 
ui_data.pop(self.cache_text_encoder_outputs) new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs) new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation) new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps) new_config.weight_dtype = ui_data.pop(self.weight_dtype) new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing) new_config.lora_rank_dim = ui_data.pop(self.lora_rank_dim) new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma) new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt) positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts)) new_config.validation_prompts = positive_prompts new_config.negative_validation_prompts = negative_prompts new_config.data_loader = self.image_caption_sd_data_loader_config_group.update_config_with_ui_component_data( new_config.data_loader, ui_data ) new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data) new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data( new_config.optimizer, ui_data ) # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred # to the config. 
assert len(ui_data) == 0 return new_config ================================================ FILE: src/invoke_training/ui/config_groups/sdxl_textual_inversion_config_group.py ================================================ import typing import gradio as gr from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig from invoke_training.ui.config_groups.base_pipeline_config_group import BasePipelineConfigGroup from invoke_training.ui.config_groups.optimizer_config_group import OptimizerConfigGroup from invoke_training.ui.config_groups.textual_inversion_sd_data_loader_config_group import ( TextualInversionSDDataLoaderConfigGroup, ) from invoke_training.ui.config_groups.ui_config_element import UIConfigElement from invoke_training.ui.utils.prompts import ( convert_pos_neg_prompts_to_ui_prompts, convert_ui_prompts_to_pos_neg_prompts, ) from invoke_training.ui.utils.utils import get_typing_literal_options class SdxlTextualInversionConfigGroup(UIConfigElement): def __init__(self): """The SDXL_TEXTUAL_INVERSION configs.""" gr.Markdown("## Basic Configs") with gr.Row(): with gr.Column(scale=1): with gr.Tab("Base Model"): self.model = gr.Textbox( label="Model", info="The base model. 
Can be a Hugging Face Hub model name, or a path to a local model (in " "diffusers or checkpoint format).", type="text", interactive=True, ) self.hf_variant = gr.Textbox( label="Variant", info="(optional) The Hugging Face hub model variant (e.g., fp16, fp32) to use if the model is a" " HF Hub model name.", type="text", interactive=True, ) self.vae_model = gr.Textbox( label="VAE Model", info="(optional) If set, this overrides the base model's default VAE model.", type="text", interactive=True, ) with gr.Column(scale=3): with gr.Tab("Training Outputs"): self.base_pipeline_config_group = BasePipelineConfigGroup() self.max_checkpoints = gr.Number( label="Maximum Number of Checkpoints", info="The maximum number of checkpoints to keep on disk from this training run. Earlier " "checkpoints will be deleted to respect this limit.", interactive=True, precision=0, ) gr.Markdown("## Data Configs") self.textual_inversion_sd_data_loader_config_group = TextualInversionSDDataLoaderConfigGroup() gr.Markdown("## Textual Inversion Configs") self.num_vectors = gr.Number( label="Num Vectors", info="The number of TI vectors that will be trained. Can be overriden by 'Initial Phrase'.", interactive=True, precision=0, ) self.placeholder_token = gr.Textbox( label="Placeholder Token", info="The special word to associate the learned embeddings with. Choose a unique token that is unlikely to " "already exist in the tokenizer's vocabulary.", interactive=True, ) self.initializer_token = gr.Textbox( label="Initializer Token", info="Only one of 'Initializer Token' or 'Initial Phrase' should be set. A vocabulary token to use as an " "initializer for the placeholder token. It should be a single word that roughly describes the object or " "style that you're trying to train on. The initializer token ust map to a single tokenizer token.", interactive=True, ) self.initial_phrase = gr.Textbox( label="Initial Phrase", info="Only one of 'Initializer Token' or 'Initial Phrase' should be set. 
A phrase that will be used to " "initialize the placeholder token embedding. The phrase will be tokenized, and the corresponding " "embeddings will be used to initialize the placeholder tokens. The number of embedding vectors will be " "inferred from the length of the tokenized phrase, so keep the phrase short.", interactive=True, ) gr.Markdown("## Optimizer Configs") self.optimizer_config_group = OptimizerConfigGroup() gr.Markdown("## Speed / Memory Configs") with gr.Group(): with gr.Row(): self.gradient_accumulation_steps = gr.Number( label="Gradient Accumulation Steps", info="The number of gradient steps to accumulate before each weight update. This is an" " alternative to increasing the batch size when training with limited VRAM." "effective_batch_size = train_batch_size * gradient_accumulation_steps.", precision=0, interactive=True, ) with gr.Row(): self.weight_dtype = gr.Dropdown( label="Weight Type", info="The precision of the model weights. Lower precision can speed up training and reduce memory, " "with increased risk of numerical stability issues. 'bfloat16' is recommended for most use cases " "if your GPU supports it.", choices=get_typing_literal_options(SdxlTextualInversionConfig, "weight_dtype"), interactive=True, ) with gr.Row(): self.cache_vae_outputs = gr.Checkbox( label="Cache VAE Outputs", info="Cache the VAE outputs to increase speed. This should not be used when training the UNet or " "performing data augmentations that would change the VAE outputs.", interactive=True, ) with gr.Row(): self.enable_cpu_offload_during_validation = gr.Checkbox( label="Enable CPU Offload during Validation", info="Offload models to the CPU sequentially during validation. 
This reduces peak VRAM " "requirements at the cost of slower validation during training.", interactive=True, ) self.gradient_checkpointing = gr.Checkbox( label="Gradient Checkpointing", info="If True, VRAM requirements are reduced at the cost of ~20% slower training", interactive=True, ) gr.Markdown("## General Training Configs") with gr.Tab("Core"): with gr.Row(): self.lr_scheduler = gr.Dropdown( label="Learning Rate Scheduler", choices=get_typing_literal_options(SdxlTextualInversionConfig, "lr_scheduler"), interactive=True, ) self.lr_warmup_steps = gr.Number( label="Warmup Steps", info="The number of warmup steps in the " "learning rate schedule, if applicable to the selected scheduler.", interactive=True, ) with gr.Row(): self.use_masks = gr.Checkbox( label="Use Masks", info="This can only be enabled if the dataset contains masks.", interactive=True ) with gr.Tab("Advanced"): with gr.Column(): self.min_snr_gamma = gr.Number( label="Minumum SNR Gamma", info="min_snr_gamma acts like an an upper bound on the weight of samples with low noise " "levels. If None, then Min-SNR weighting will not be applied. If enabled, the recommended " "value is min_snr gamma = 5.0.", interactive=True, ) self.max_grad_norm = gr.Number( label="Max Gradient Norm", info="Max gradient norm for clipping. Set to 0 or leave empty for no clipping (null).", interactive=True, ) self.train_batch_size = gr.Number( label="Batch Size", info="The Training Batch Size - Higher values require increasing amounts of VRAM.", precision=0, interactive=True, ) gr.Markdown("## Validation") with gr.Group(): self.validation_prompts = gr.Textbox( label="Validation Prompts", info="Enter one validation prompt per line. Optionally, add negative prompts after a '[NEG]' " "delimiter. For example: `positive prompt[NEG]negative prompt`. 
", lines=5, interactive=True, ) self.num_validation_images_per_prompt = gr.Number( label="# of Validation Images to Generate per Prompt", precision=0, interactive=True ) def update_ui_components_with_config_data( self, config: SdxlTextualInversionConfig ) -> dict[gr.components.Component, typing.Any]: update_dict = { self.model: config.model, self.hf_variant: config.hf_variant, self.vae_model: config.vae_model, self.num_vectors: config.num_vectors, self.placeholder_token: config.placeholder_token, self.initializer_token: config.initializer_token, self.initial_phrase: config.initial_phrase, self.max_checkpoints: config.max_checkpoints, self.lr_scheduler: config.lr_scheduler, self.lr_warmup_steps: config.lr_warmup_steps, self.use_masks: config.use_masks, self.max_grad_norm: config.max_grad_norm, self.train_batch_size: config.train_batch_size, self.cache_vae_outputs: config.cache_vae_outputs, self.enable_cpu_offload_during_validation: config.enable_cpu_offload_during_validation, self.gradient_accumulation_steps: config.gradient_accumulation_steps, self.weight_dtype: config.weight_dtype, self.gradient_checkpointing: config.gradient_checkpointing, self.min_snr_gamma: config.min_snr_gamma, self.validation_prompts: convert_pos_neg_prompts_to_ui_prompts( config.validation_prompts, config.negative_validation_prompts ), self.num_validation_images_per_prompt: config.num_validation_images_per_prompt, } update_dict.update( self.textual_inversion_sd_data_loader_config_group.update_ui_components_with_config_data(config.data_loader) ) update_dict.update(self.base_pipeline_config_group.update_ui_components_with_config_data(config)) update_dict.update(self.optimizer_config_group.update_ui_components_with_config_data(config.optimizer)) # Sanity check to catch if we accidentally forget to update a UI component. 
assert set(update_dict.keys()) == set(self.get_ui_output_components()) return update_dict def update_config_with_ui_component_data( self, orig_config: SdxlTextualInversionConfig, ui_data: dict[gr.components.Component, typing.Any] ) -> SdxlTextualInversionConfig: new_config = orig_config.model_copy(deep=True) new_config.model = ui_data.pop(self.model) new_config.hf_variant = ui_data.pop(self.hf_variant) or None new_config.vae_model = ui_data.pop(self.vae_model) or None new_config.num_vectors = ui_data.pop(self.num_vectors) new_config.placeholder_token = ui_data.pop(self.placeholder_token) new_config.initializer_token = ui_data.pop(self.initializer_token) or None new_config.initial_phrase = ui_data.pop(self.initial_phrase) or None new_config.max_checkpoints = ui_data.pop(self.max_checkpoints) new_config.lr_scheduler = ui_data.pop(self.lr_scheduler) new_config.lr_warmup_steps = ui_data.pop(self.lr_warmup_steps) new_config.use_masks = ui_data.pop(self.use_masks) max_grad_norm_value = ui_data.pop(self.max_grad_norm) new_config.max_grad_norm = None if max_grad_norm_value == 0 else max_grad_norm_value new_config.train_batch_size = ui_data.pop(self.train_batch_size) new_config.cache_vae_outputs = ui_data.pop(self.cache_vae_outputs) new_config.enable_cpu_offload_during_validation = ui_data.pop(self.enable_cpu_offload_during_validation) new_config.gradient_accumulation_steps = ui_data.pop(self.gradient_accumulation_steps) new_config.weight_dtype = ui_data.pop(self.weight_dtype) new_config.gradient_checkpointing = ui_data.pop(self.gradient_checkpointing) new_config.min_snr_gamma = ui_data.pop(self.min_snr_gamma) new_config.num_validation_images_per_prompt = ui_data.pop(self.num_validation_images_per_prompt) positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_data.pop(self.validation_prompts)) new_config.validation_prompts = positive_prompts new_config.negative_validation_prompts = negative_prompts new_config.data_loader = ( 
self.textual_inversion_sd_data_loader_config_group.update_config_with_ui_component_data( new_config.data_loader, ui_data ) ) new_config = self.base_pipeline_config_group.update_config_with_ui_component_data(new_config, ui_data) new_config.optimizer = self.optimizer_config_group.update_config_with_ui_component_data( new_config.optimizer, ui_data ) # We pop items from ui_data as we use them so that we can sanity check that all the input data was transferred # to the config. assert len(ui_data) == 0 return new_config ================================================ FILE: src/invoke_training/ui/config_groups/textual_inversion_sd_data_loader_config_group.py ================================================ from typing import Any import gradio as gr from invoke_training.config.data.data_loader_config import ( TextualInversionSDDataLoaderConfig, ) from invoke_training.ui.config_groups.aspect_ratio_bucket_config_group import AspectRatioBucketConfigGroup from invoke_training.ui.config_groups.dataset_config_group import DatasetConfigGroup from invoke_training.ui.config_groups.ui_config_element import UIConfigElement class TextualInversionSDDataLoaderConfigGroup(UIConfigElement): def __init__(self): with gr.Row(): with gr.Column(scale=1): with gr.Tab("Data Source Configs"): with gr.Group(): self.dataset = DatasetConfigGroup( allowed_types=[ "HF_HUB_IMAGE_CAPTION_DATASET", "IMAGE_CAPTION_JSONL_DATASET", "IMAGE_CAPTION_DIR_DATASET", "IMAGE_DIR_DATASET", ] ) with gr.Column(scale=3): with gr.Tab("Data Loading Configs"): with gr.Group(): self.caption_preset = gr.Dropdown( label="Caption Preset", choices=["None", "style", "object"], info="Only one of 'Caption Preset' or 'Caption Templates' should be set.\nSelect a Caption " "Preset option to use a set of pre-configured templates.", interactive=True, ) self.caption_templates = gr.Textbox( label="Caption Templates", info="Only one of 'Caption Preset' or 'Caption Templates' should be set. Enter one template" " per line. 
Each template should contain a single placeholder token slot indicated by '{}'," " for example 'a photo of a {}'.", lines=5, interactive=True, ) with gr.Row(): self.keep_original_captions = gr.Checkbox( label="Keep Original Captions", info="If True, the caption templates will be prepended to the original captions." " If False, the caption templates will replace the original captions.", interactive=True, ) self.shuffle_caption_delimiter = gr.Textbox( label="Shuffle Caption Delimiter", info="Set captions to split on the provided delimiter (e.g. ',') and shuffled.", interactive=True, ) with gr.Row(): self.resolution = gr.Number( label="Resolution", info="The resolution for input images. All of the images in the dataset will be" " resized to this resolution unless the aspect_ratio_buckets config is set.", precision=0, interactive=True, ) self.dataloader_num_workers = gr.Number( label="Dataloading Workers", info="Number of subprocesses to use for data loading. 0 means that the data will" " be loaded in the main process.", precision=0, interactive=True, ) with gr.Row(): self.center_crop = gr.Checkbox( label="Center Crop", info="If set, input images will be center-cropped to the target resolution. Otherwise," " input images will be randomly cropped to the target resolution.", interactive=True, ) self.random_flip = gr.Checkbox( label="Random Flip", info="If set, random flip augmentations will be applied to input images.", interactive=True, ) with gr.Tab("Aspect Ratio Bucketing Configs"): self.aspect_ratio_bucket_config_group = AspectRatioBucketConfigGroup() def update_ui_components_with_config_data( self, config: TextualInversionSDDataLoaderConfig ) -> dict[gr.components.Component, Any]: # Special handling of caption_preset to translate None to "None". 
caption_preset = "None" if config.caption_preset is not None: caption_preset = config.caption_preset update_dict = { self.caption_preset: caption_preset, self.caption_templates: "\n".join(config.caption_templates or []), self.keep_original_captions: config.keep_original_captions, self.shuffle_caption_delimiter: config.shuffle_caption_delimiter, self.resolution: config.resolution, self.center_crop: config.center_crop, self.random_flip: config.random_flip, self.dataloader_num_workers: config.dataloader_num_workers, } update_dict.update(self.dataset.update_ui_components_with_config_data(config.dataset)) update_dict.update( self.aspect_ratio_bucket_config_group.update_ui_components_with_config_data(config.aspect_ratio_buckets) ) return update_dict def update_config_with_ui_component_data( self, orig_config: TextualInversionSDDataLoaderConfig, ui_data: dict[gr.components.Component, Any] ) -> TextualInversionSDDataLoaderConfig: new_config = orig_config.model_copy(deep=True) # Special handling of caption_preset to translate "None" to None. caption_presets = {"None": None, "style": "style", "object": "object"} caption_preset = caption_presets[ui_data.pop(self.caption_preset)] # Special handling of caption_templates. 
caption_templates: list[str] = ui_data.pop(self.caption_templates).split("\n") caption_templates = [x.strip() for x in caption_templates if x.strip() != ""] or None new_config.dataset = self.dataset.update_config_with_ui_component_data(orig_config.dataset, ui_data) new_config.aspect_ratio_buckets = self.aspect_ratio_bucket_config_group.update_config_with_ui_component_data( orig_config.aspect_ratio_buckets, ui_data ) new_config.caption_preset = caption_preset new_config.caption_templates = caption_templates new_config.keep_original_captions = ui_data.pop(self.keep_original_captions) new_config.shuffle_caption_delimiter = ui_data.pop(self.shuffle_caption_delimiter) or None new_config.resolution = ui_data.pop(self.resolution) new_config.center_crop = ui_data.pop(self.center_crop) new_config.random_flip = ui_data.pop(self.random_flip) new_config.dataloader_num_workers = ui_data.pop(self.dataloader_num_workers) return new_config ================================================ FILE: src/invoke_training/ui/config_groups/ui_config_element.py ================================================ from typing import Any import gradio as gr class UIConfigElement: """A base class for UI blocks that represent a part of a config.""" def get_ui_output_components(self) -> list[gr.components.Component]: """Recursively return a list of all valid output UI components.""" all_ui_components = [] for attribute in vars(self).values(): if isinstance(attribute, (gr.components.Component, gr.Group)): all_ui_components.append(attribute) elif isinstance(attribute, UIConfigElement): all_ui_components.extend(attribute.get_ui_output_components()) return all_ui_components def get_ui_input_components(self) -> list[gr.components.Component]: """Recursively return a list of all valid input UI components.""" all_ui_components = [] for attribute in vars(self).values(): if isinstance(attribute, (gr.components.Component)): all_ui_components.append(attribute) elif isinstance(attribute, UIConfigElement): 
all_ui_components.extend(attribute.get_ui_input_components()) return all_ui_components def update_ui_components_with_config_data(self, config) -> dict[gr.components.Component, Any]: """Produce a dictionary of UI components to their corresponding updated data from the config.""" raise NotImplementedError() def update_config_with_ui_component_data(self, orig_config, ui_data: dict[gr.components.Component, Any]): """Update the orig_config with the data from the UI components. Return the updated config.""" raise NotImplementedError() ================================================ FILE: src/invoke_training/ui/gradio_blocks/header.py ================================================ import gradio as gr from invoke_training.ui.utils.utils import get_assets_dir_path class Header: def __init__(self): logo_path = get_assets_dir_path() / "logo.png" gr.Image( value=logo_path, label="Invoke Training App", width=200, interactive=False, container=False, ) gr.Markdown( "[Home](/)\n\n" "*Invoke Training* - [Documentation](https://invoke-ai.github.io/invoke-training/) --" " Learn more about Invoke at [invoke.com](https://www.invoke.com/)" ) ================================================ FILE: src/invoke_training/ui/gradio_blocks/pipeline_tab.py ================================================ import typing import gradio as gr import yaml from invoke_training.config.pipeline_config import PipelineConfig from invoke_training.ui.config_groups.ui_config_element import UIConfigElement from invoke_training.ui.utils.utils import load_config_from_yaml class PipelineTab: def __init__( self, name: str, default_config_file_path: str, pipeline_config_cls: typing.Type[PipelineConfig], config_group_cls: typing.Type[UIConfigElement], run_training_cb: typing.Callable[[PipelineConfig], None], app: gr.Blocks, ): """A tab for a single training pipeline type. Args: run_training_cb (typing.Callable[[PipelineConfig], None]): A callback function to run the training process. 
""" self._name = name self._default_config_file_path = default_config_file_path self._pipeline_config_cls = pipeline_config_cls self._run_training_cb = run_training_cb # self._default_config is the config that was last loaded from the reference config file. self._default_config = None # self._current_config is the config that was most recently generated from the UI. self._current_config = None gr.Markdown(f"# {self._name} Training Config") self.reference_config_file = gr.Textbox( label="Reference Config File Path", value=default_config_file_path, interactive=True ) reset_config_button = gr.Button(value="Reload reference config") self.pipeline_config_group = config_group_cls() gr.Markdown("## Config Output") generate_config_button = gr.Button(value="Generate Config") self._config_yaml = gr.Code(label="Config YAML", language="yaml", interactive=False) gr.Markdown( """# Run Training 'Start Training' starts the training process in the background. Check the terminal for logs. **Warning: Click 'Generate Config' to capture all of the latest changes before starting training.** """ ) run_training_button = gr.Button(value="Start Training") gr.Markdown( """# Visualize Results Once you've started training, you can see the results by launching tensorboard with the following command: ```bash tensorboard --logdir /path/to/output_dir ``` Alternatively, you can browse the output directory directly to find model checkpoints, logs, and validation images. 
""" ) reset_config_button.click( self.on_reset_config_button_click, inputs=self.reference_config_file, outputs=self.pipeline_config_group.get_ui_output_components() + [self._config_yaml], ) generate_config_button.click( self.on_generate_config_button_click, inputs=set(self.pipeline_config_group.get_ui_input_components()), outputs=self.pipeline_config_group.get_ui_output_components() + [self._config_yaml], ) run_training_button.click(self.on_run_training_button_click, inputs=[], outputs=[]) # On app load, reset the configs based on the default reference config file. # We'll wrap this in a try-except block to handle any errors during loading def safe_load_config(file_path): try: return self.on_reset_config_button_click(file_path) except Exception as e: print(f"Error during app.load for {self._name}: {e}") # Return empty values for all outputs to avoid UI errors output_components = self.pipeline_config_group.get_ui_output_components() + [self._config_yaml] return {comp: None for comp in output_components} app.load( safe_load_config, inputs=self.reference_config_file, outputs=self.pipeline_config_group.get_ui_output_components() + [self._config_yaml], ) def on_reset_config_button_click(self, file_path: str): try: print(f"Resetting UI configs for {self._name} to {file_path}.") default_config = load_config_from_yaml(file_path) if not isinstance(default_config, self._pipeline_config_cls): raise TypeError( f"Wrong config type. Expected '{self._pipeline_config_cls.__name__}', got " f"'{type(default_config).__name__}'." 
) self._default_config = default_config self._current_config = self._default_config.model_copy(deep=True) update_dict = self.pipeline_config_group.update_ui_components_with_config_data(self._current_config) update_dict.update({self._config_yaml: None}) return update_dict except Exception as e: print(f"Error resetting config: {e}") # Return a minimal update dict to avoid UI errors if self._current_config: return { self._config_yaml: yaml.safe_dump( self._current_config.model_dump(), default_flow_style=False, sort_keys=False ) } return {self._config_yaml: f"Error loading config: {e}"} def on_generate_config_button_click(self, data: dict): try: print(f"Generating config for {self._name}.") self._current_config = self.pipeline_config_group.update_config_with_ui_component_data( self._current_config, data ) # Roundtrip to make sure that the config is valid. self._current_config = self._pipeline_config_cls.model_validate(self._current_config.model_dump()) # Update the UI to reflect the new state of the config # (in case some values were rounded or otherwise modified # in the process). update_dict = self.pipeline_config_group.update_ui_components_with_config_data(self._current_config) update_dict.update( { self._config_yaml: yaml.safe_dump( self._current_config.model_dump(), default_flow_style=False, sort_keys=False ) } ) return update_dict except Exception as e: print(f"Error generating config: {e}") # Return a minimal update dict to avoid UI errors if self._current_config: return { self._config_yaml: yaml.safe_dump( self._current_config.model_dump(), default_flow_style=False, sort_keys=False ) } return {self._config_yaml: f"Error generating config: {e}"} def on_run_training_button_click(self): self._run_training_cb(self._current_config) ================================================ FILE: src/invoke_training/ui/index.html ================================================ invoke-training
Invoke logo.

invoke-training

Invoke Training - Documentation

Learn more about Invoke at invoke.com

================================================ FILE: src/invoke_training/ui/pages/data_page.py ================================================ from pathlib import Path import gradio as gr from PIL import Image from invoke_training._shared.data.datasets.image_caption_jsonl_dataset import ( CAPTION_COLUMN_DEFAULT, IMAGE_COLUMN_DEFAULT, ImageCaptionExample, ImageCaptionJsonlDataset, ) from invoke_training._shared.utils.jsonl import save_jsonl from invoke_training.ui.gradio_blocks.header import Header IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png"] class DataPage: def __init__(self): # The dataset that is currently being edited. self._jsonl_path: str | None = None self._dataset: ImageCaptionJsonlDataset | None = None # Define the theme with dark mode as default theme = gr.themes.Default( # Optional: Customize colors, fonts, etc. # primary_hue=gr.themes.colors.blue, # ... ) theme._dark_mode = True # Custom CSS custom_css = """ .dark { /* Override the default accent color for dark mode */ --color-accent: #e6fd13 !important; --color-accent-soft: #e6fd1333 !important; /* Optional: Adjust soft accent too */ } .dark .tabs button[aria-selected="true"] { /* Keep selected tab text color override */ color: #e6fd13 !important; /* Optional: Remove background if --color-accent handles it */ /* background-color: transparent !important; */ } /* Style checkbox checkmark in dark mode when checked */ .dark input[type="checkbox"]:checked + span svg path { /* Target the SVG path inside the checked checkbox */ stroke: black !important; /* Set the checkmark color to black */ } """ # Pass the theme and css to gr.Blocks with gr.Blocks( theme=theme, css=custom_css, # Use updated CSS title="invoke-training", analytics_enabled=False, head='', ) as app: self._header = Header() gr.Markdown("# Data Annotation") gr.Markdown( "Note: This UI creates datasets in `IMAGE_CAPTION_JSONL_DATASET` format. 
For more information about " "this format see [the docs](https://invoke-ai.github.io/invoke-training/concepts/dataset_formats/)" ) # HACK: I use a column as a wrapper to control visbility of this group of UI elements. gr.Group sounds like # a more natural choice for this purpose, but it applies some styling that makes the group look weird. with gr.Column() as select_dataset_group: gr.Markdown("## Load Existing Dataset") with gr.Group(): self._existing_jsonl_path = gr.Textbox( label="Existing .jsonl Path", info="Enter the path to an existing dataset's .jsonl file.", placeholder="/path/to/dataset.jsonl", ) with gr.Row(): self._image_column_textbox = gr.Textbox( label="Image Column (Optional)", placeholder=IMAGE_COLUMN_DEFAULT ) self._caption_column_textbox = gr.Textbox( label="Caption Column (Optional)", placeholder=CAPTION_COLUMN_DEFAULT ) self._load_existing_dataset_button = gr.Button("Load Existing Dataset") gr.Markdown("## Create New Dataset") with gr.Group(): self._new_jsonl_path = gr.Textbox( label="New .jsonl Path", info="Enter the path for a new .jsonl file.", placeholder="/path/to/dataset.jsonl", ) self._create_new_dataset_button = gr.Button("Create New Dataset") self._select_dataset_group = select_dataset_group # HACK: I use a column as a wrapper to control visbility of this group of UI elements. gr.Group sounds like # a more natural choice for this purpose, but it applies some styling that makes the group look weird. with gr.Column(visible=False) as edit_dataset_group: with gr.Row(): self._current_jsonl_path = gr.Textbox(label="Currently editing:", interactive=False) self._change_dataset_button = gr.Button("Change") gr.Markdown("## Add Images") with gr.Group(): self._image_source_textbox = gr.Textbox( label="Image Source", info="Enter the path to a single image or a directory containing images. 
If a directory path " "is passed, it will be searched recursively for image files.", placeholder="/path/to/image_dir", ) self._add_images_button = gr.Button("Add Images") gr.Markdown("## Edit Captions") with gr.Row(): with gr.Column(): with gr.Row(): self._cur_example_index = gr.Number(label="Current index", precision=0, interactive=True) self._cur_len_number = gr.Number(label="Dataset length", interactive=False) with gr.Row(): self._beyond_dataset_limits_warning = gr.Markdown( "**Current index is beyond dataset limits.** If you have completed all captions, click " "'Home' to begin training." ) with gr.Row(): self._cur_image = gr.Image(value=None, label="Image", interactive=False, width=500) with gr.Column(): self._cur_caption = gr.Textbox(label="Caption", interactive=True, lines=25) with gr.Row(): self._save_and_prev_button = gr.Button("Save and Go-To Previous") self._save_and_next_button = gr.Button("Save and Go-To Next") gr.Markdown("## Raw JSONL") self._data_jsonl = gr.Code(label="Dataset .jsonl", language="json", interactive=False) self._edit_dataset_group = edit_dataset_group self._app = app standard_outputs = [ self._select_dataset_group, self._edit_dataset_group, self._current_jsonl_path, self._cur_len_number, self._cur_example_index, self._cur_image, self._cur_caption, self._beyond_dataset_limits_warning, self._data_jsonl, ] self._load_existing_dataset_button.click( self._on_load_existing_dataset_button_click, inputs=set([self._existing_jsonl_path, self._image_column_textbox, self._caption_column_textbox]), outputs=standard_outputs, ) self._create_new_dataset_button.click( self._on_create_dataset_button_click, inputs=set([self._new_jsonl_path]), outputs=standard_outputs, ) self._change_dataset_button.click( self._on_change_dataset_button_click, inputs=None, outputs=standard_outputs ) self._save_and_prev_button.click( self._on_save_and_prev_button_click, inputs=set([self._cur_example_index, self._cur_caption]), outputs=standard_outputs, ) 
self._save_and_next_button.click( self._on_save_and_next_button_click, inputs=set([self._cur_example_index, self._cur_caption]), outputs=standard_outputs, ) self._add_images_button.click( self._on_add_images_button_click, inputs=set([self._image_source_textbox]), outputs=standard_outputs, ) self._cur_example_index.input( self._on_cur_example_index_change, inputs=set([self._cur_example_index]), outputs=standard_outputs, ) def _update_state(self, idx: int): if self._dataset is None or self._jsonl_path is None: return { self._select_dataset_group: gr.Group(visible=True), self._edit_dataset_group: gr.Column(visible=False), self._current_jsonl_path: None, self._cur_len_number: 0, self._cur_example_index: 0, self._cur_image: None, self._cur_caption: None, self._beyond_dataset_limits_warning: gr.Markdown(visible=False), self._data_jsonl: "", } idx = idx image = None caption = None beyond_limits = True if 0 <= idx and idx < len(self._dataset): beyond_limits = False example = self._dataset[idx] image: Image.Image = example["image"] caption = example["caption"] # Resize the image to have a max dimension of 1024. On slow connections, sending the full-size image can be # very slow. 
max_dim = 1024 if image.width > max_dim or image.height > max_dim: scale = max_dim / max(image.width, image.height) image = image.resize((int(image.width * scale), int(image.height * scale))) jsonl_str = "\n".join([example.model_dump_json() for example in self._dataset.examples]) return { self._select_dataset_group: gr.Group(visible=self._dataset is None), self._edit_dataset_group: gr.Column(visible=self._dataset is not None), self._current_jsonl_path: str(self._jsonl_path), self._cur_len_number: len(self._dataset), self._cur_example_index: idx, self._cur_image: image, self._cur_caption: caption, self._beyond_dataset_limits_warning: gr.Markdown(visible=beyond_limits), self._data_jsonl: jsonl_str, } def _on_load_existing_dataset_button_click(self, data: dict): """Load an existing dataset.""" jsonl_path = Path(data[self._existing_jsonl_path]) jsonl_path = jsonl_path.resolve() if not jsonl_path.exists(): raise ValueError(f"'{jsonl_path}' does not exist.") self._jsonl_path = jsonl_path self._dataset = ImageCaptionJsonlDataset( jsonl_path=jsonl_path, image_column=data[self._image_column_textbox] or IMAGE_COLUMN_DEFAULT, caption_column=data[self._caption_column_textbox] or CAPTION_COLUMN_DEFAULT, ) return self._update_state(0) def _on_create_dataset_button_click(self, data: dict): """Create a new dataset.""" jsonl_path = Path(data[self._new_jsonl_path]) jsonl_path = jsonl_path.resolve() if jsonl_path.exists(): raise ValueError(f"'{jsonl_path}' already exists.") if jsonl_path.suffix != ".jsonl": raise ValueError("Invalid file extension. Expected '.jsonl'.") print(f"Creating new dataset at '{jsonl_path}'.") jsonl_path.parent.mkdir(parents=True, exist_ok=True) # Create an empty jsonl file. 
save_jsonl([], jsonl_path) self._jsonl_path = jsonl_path self._dataset = ImageCaptionJsonlDataset(jsonl_path=jsonl_path) return self._update_state(0) def _on_change_dataset_button_click(self): self._jsonl_path = None self._dataset = None return self._update_state(0) def _on_save_and_go_button_click(self, data: dict, idx_change: int): # Update the current caption and re-save the jsonl file. idx: int = data[self._cur_example_index] if idx < 0 or idx >= len(self._dataset): # idx is out of bounds, so don't update the caption, but still change the index. return self._update_state(idx + idx_change) print(f"Updating caption for example {idx} of '{self._jsonl_path}'.") caption = data[self._cur_caption] self._dataset.examples[idx].caption = caption self._dataset.save_jsonl() return self._update_state(idx + idx_change) def _on_save_and_next_button_click(self, data: dict): return self._on_save_and_go_button_click(data, 1) def _on_save_and_prev_button_click(self, data: dict): return self._on_save_and_go_button_click(data, -1) def _on_cur_example_index_change(self, data: dict): return self._update_state(data[self._cur_example_index]) def _on_add_images_button_click(self, data: dict): """Add images to the dataset.""" image_source_path = Path(data[self._image_source_textbox]) if not image_source_path.exists(): raise ValueError(f"'{image_source_path}' does not exist.") # Determine the list of image paths to add to the dataset. image_paths = [] if image_source_path.is_file(): if image_source_path.suffix.lower() not in IMAGE_EXTENSIONS: raise ValueError( f"'{image_source_path}' is not a valid image file. Expected one of {IMAGE_EXTENSIONS}." ) image_paths.append(image_source_path.resolve()) else: # Recursively search for image files in the image_source_path directory. for file_path in image_source_path.glob("**/*"): if file_path.is_file() and file_path.suffix.lower() in IMAGE_EXTENSIONS: image_paths.append(file_path.resolve()) # Avoid adding duplicate images. 
cur_image_paths = set([Path(example.image_path) for example in self._dataset.examples]) image_paths = set(image_paths) new_image_paths = image_paths - cur_image_paths if len(new_image_paths) < len(image_paths): print(f"Skipping {len(image_paths) - len(new_image_paths)} images that are already in the dataset.") # Add the new images to the dataset. print(f"Adding {len(new_image_paths)} images to '{self._jsonl_path}'.") for image_path in new_image_paths: self._dataset.examples.append(ImageCaptionExample(image_path=str(image_path), caption="")) # Save the updated dataset. self._dataset.save_jsonl() return self._update_state(0) def app(self): return self._app ================================================ FILE: src/invoke_training/ui/pages/training_page.py ================================================ import os import subprocess import tempfile import time import gradio as gr import yaml from invoke_training.config.pipeline_config import PipelineConfig from invoke_training.pipelines.flux.lora.config import FluxLoraConfig from invoke_training.pipelines.stable_diffusion.lora.config import SdLoraConfig from invoke_training.pipelines.stable_diffusion.textual_inversion.config import SdTextualInversionConfig from invoke_training.pipelines.stable_diffusion_xl.finetune.config import SdxlFinetuneConfig from invoke_training.pipelines.stable_diffusion_xl.lora.config import SdxlLoraConfig from invoke_training.pipelines.stable_diffusion_xl.lora_and_textual_inversion.config import ( SdxlLoraAndTextualInversionConfig, ) from invoke_training.pipelines.stable_diffusion_xl.textual_inversion.config import SdxlTextualInversionConfig from invoke_training.ui.config_groups.flux_lora_config_group import FluxLoraConfigGroup from invoke_training.ui.config_groups.sd_lora_config_group import SdLoraConfigGroup from invoke_training.ui.config_groups.sd_textual_inversion_config_group import SdTextualInversionConfigGroup from invoke_training.ui.config_groups.sdxl_finetune_config_group import 
SdxlFinetuneConfigGroup from invoke_training.ui.config_groups.sdxl_lora_and_textual_inversion_config_group import ( SdxlLoraAndTextualInversionConfigGroup, ) from invoke_training.ui.config_groups.sdxl_lora_config_group import SdxlLoraConfigGroup from invoke_training.ui.config_groups.sdxl_textual_inversion_config_group import SdxlTextualInversionConfigGroup from invoke_training.ui.gradio_blocks.header import Header from invoke_training.ui.gradio_blocks.pipeline_tab import PipelineTab from invoke_training.ui.utils.utils import get_config_dir_path class TrainingPage: def __init__(self): self._config_temp_directory = tempfile.TemporaryDirectory() self._training_process = None # Define the theme with dark mode as default theme = gr.themes.Default() theme._dark_mode = True # Custom CSS custom_css = """ .dark { /* Override the default accent color for dark mode */ --color-accent: #e6fd13 !important; --color-accent-soft: #e6fd1333 !important; /* Optional: Adjust soft accent too */ } .dark .tabs button[aria-selected="true"] { /* Keep selected tab text color override */ color: #e6fd13 !important; } /* Style checkbox checkmark in dark mode when checked */ .dark input[type="checkbox"]:checked + span svg path { /* Target the SVG path inside the checked checkbox */ stroke: black !important; /* Set the checkmark color to black */ } """ # Pass the theme and css to gr.Blocks with gr.Blocks( theme=theme, css=custom_css, title="invoke-training", analytics_enabled=False, head=""" """, ) as app: self._header = Header() with gr.Tab(label="SD LoRA"): PipelineTab( name="SD LoRA", default_config_file_path=str(get_config_dir_path() / "sd_lora_baroque_1x8gb.yaml"), pipeline_config_cls=SdLoraConfig, config_group_cls=SdLoraConfigGroup, run_training_cb=self._run_training, app=app, ) with gr.Tab(label="SDXL LoRA"): PipelineTab( name="SDXL LoRA", default_config_file_path=str(get_config_dir_path() / "sdxl_lora_baroque_1x24gb.yaml"), pipeline_config_cls=SdxlLoraConfig, 
config_group_cls=SdxlLoraConfigGroup, run_training_cb=self._run_training, app=app, ) with gr.Tab(label="SD Textual Inversion"): PipelineTab( name="SD Textual Inversion", default_config_file_path=str(get_config_dir_path() / "sd_textual_inversion_gnome_1x8gb.yaml"), pipeline_config_cls=SdTextualInversionConfig, config_group_cls=SdTextualInversionConfigGroup, run_training_cb=self._run_training, app=app, ) with gr.Tab(label="SDXL Textual Inversion"): PipelineTab( name="SDXL Textual Inversion", default_config_file_path=str(get_config_dir_path() / "sdxl_textual_inversion_gnome_1x24gb.yaml"), pipeline_config_cls=SdxlTextualInversionConfig, config_group_cls=SdxlTextualInversionConfigGroup, run_training_cb=self._run_training, app=app, ) with gr.Tab(label="SDXL LoRA and Textual Inversion"): PipelineTab( name="SDXL LoRA and Textual Inversion", default_config_file_path=str(get_config_dir_path() / "sdxl_lora_and_ti_gnome_1x24gb.yaml"), pipeline_config_cls=SdxlLoraAndTextualInversionConfig, config_group_cls=SdxlLoraAndTextualInversionConfigGroup, run_training_cb=self._run_training, app=app, ) with gr.Tab(label="SDXL Finetune"): PipelineTab( name="SDXL Finetune", default_config_file_path=str(get_config_dir_path() / "sdxl_finetune_baroque_1x24gb.yaml"), pipeline_config_cls=SdxlFinetuneConfig, config_group_cls=SdxlFinetuneConfigGroup, run_training_cb=self._run_training, app=app, ) with gr.Tab(label="Flux LoRA"): PipelineTab( name="Flux LoRA", default_config_file_path=str( get_config_dir_path() / "flux_lora_1x40gb.yaml" ), # Changed from 8gb to 40gb # noqa: E501 pipeline_config_cls=FluxLoraConfig, config_group_cls=FluxLoraConfigGroup, run_training_cb=self._run_training, app=app, ) self._app = app def app(self): return self._app def _run_training(self, config: PipelineConfig): # Check if there is already a training process running. 
if self._training_process is not None: if self._training_process.poll() is None: print( "Tried to start a new training process, but another training process is already running. " "Terminate the existing process first." ) return else: self._training_process = None print(f"Starting {config.type} training...") # Write the config to a temporary config file where the training subprocess can read it. timestamp = str(time.time()).replace(".", "_") config_path = os.path.join(self._config_temp_directory.name, f"{timestamp}.yaml") with open(config_path, "w") as f: yaml.safe_dump(config.model_dump(), f, default_flow_style=False, sort_keys=False) self._training_process = subprocess.Popen(["invoke-train", "-c", str(config_path)]) print(f"Started {config.type} training.") ================================================ FILE: src/invoke_training/ui/utils/prompts.py ================================================ NEGATIVE_PROMPT_DELIMITER = "[NEG]" def split_pos_neg_prompts(prompt: str) -> tuple[str, str]: """Split a prompt containing a '[NEG]' delimiter into a positive prompt and a negative prompt. Examples: - 'positive prompt[NEG]negative prompt' -> ('positive prompt', 'negative prompt') - 'positive prompt' -> ('positive prompt', '') - 'positive prompt[NEG]negative[NEG]prompt' -> Raises ValueError """ prompt = prompt.strip() splits = prompt.split(NEGATIVE_PROMPT_DELIMITER) if len(splits) == 1: # This is a positive prompt only. return splits[0], "" elif len(splits) == 2: # This is a positive prompt followed by a negative prompt. return splits[0], splits[1] raise ValueError( f"Failed to split the prompt into a positive and negative prompt. Expected at most one " f"'{NEGATIVE_PROMPT_DELIMITER}' delimiter. Prompt: '{prompt}'." 
) def merge_pos_neg_prompts(positive_prompt: str, negative_prompt: str) -> str: """Merge a positive prompt and a negative prompt into a single prompt of the form: 'positive prompt[NEG]negative prompt' """ if NEGATIVE_PROMPT_DELIMITER in positive_prompt: raise ValueError( f"Positive prompt cannot contain the '{NEGATIVE_PROMPT_DELIMITER}' delimiter. Prompt: '{positive_prompt}'" ) if NEGATIVE_PROMPT_DELIMITER in negative_prompt: raise ValueError( f"Negative prompt cannot contain the '{NEGATIVE_PROMPT_DELIMITER}' delimiter. Prompt: '{negative_prompt}'" ) if negative_prompt == "": return positive_prompt return f"{positive_prompt}{NEGATIVE_PROMPT_DELIMITER}{negative_prompt}" def convert_ui_prompts_to_pos_neg_prompts(prompts: str) -> tuple[list[str], list[str] | None]: """Convert prompts from the UI textbox format to lists of positive and negative prompts.""" ui_prompt_list = prompts.split("\n") positive_prompts = [] negative_prompts = [] for prompt in ui_prompt_list: positive_prompt, negative_prompt = split_pos_neg_prompts(prompt) # Skip empty lines. if positive_prompt == "" and negative_prompt == "": continue positive_prompts.append(positive_prompt) negative_prompts.append(negative_prompt) # Convert negative_prompts to None if all negative prompts are empty. 
if all([p == "" for p in negative_prompts]): negative_prompts = None return positive_prompts, negative_prompts def convert_pos_neg_prompts_to_ui_prompts(positive_prompts: list[str], negative_prompts: list[str] | None) -> str: """Convert lists of positive and negative prompts to the UI textbox format.""" if negative_prompts is None: negative_prompts = [""] * len(positive_prompts) ui_prompts = "" for positive_prompt, negative_prompt in zip(positive_prompts, negative_prompts, strict=True): ui_prompts += merge_pos_neg_prompts(positive_prompt, negative_prompt) + "\n" return ui_prompts.strip() ================================================ FILE: src/invoke_training/ui/utils/utils.py ================================================ import typing from pathlib import Path import yaml from pydantic import TypeAdapter from invoke_training.config.pipeline_config import PipelineConfig def get_config_dir_path() -> Path: p = Path(__file__).parent.parent.parent / "sample_configs" if not p.exists(): raise FileNotFoundError(f"Config directory not found: '{p}'") return p def get_assets_dir_path() -> Path: p = Path(__file__).parent.parent.parent / "assets" if not p.exists(): pass return p def load_config_from_yaml(file_path: Path | str) -> PipelineConfig: file_path = Path(file_path) with open(file_path, "r") as f: cfg = yaml.safe_load(f) pipeline_adapter: TypeAdapter[PipelineConfig] = TypeAdapter(PipelineConfig) train_config = pipeline_adapter.validate_python(cfg) return train_config def get_typing_literal_options(cls, field_name: str) -> list[str]: literal_type_hint = typing.get_type_hints(cls)[field_name] return list(typing.get_args(literal_type_hint)) ================================================ FILE: tests/invoke_training/_shared/__init__.py ================================================ ================================================ FILE: tests/invoke_training/_shared/checkpoints/test_checkpoint_tracker.py ================================================ import os 
import tempfile from pathlib import Path import pytest from invoke_training._shared.checkpoints.checkpoint_tracker import CheckpointTracker def test_checkpoint_tracker_get_path_file(): """Test the CheckpointTracker.get_path(...) method with an extension.""" checkpoint_tracker = CheckpointTracker( base_dir="base_dir", prefix="prefix", extension=".ckpt", index_padding=8, ) path = checkpoint_tracker.get_path(epoch=1, step=55) assert Path(path) == Path("base_dir/prefix-epoch_00000001-step_00000055.ckpt") def test_checkpoint_tracker_get_path_directory(): """Test the CheckpointTracker.get_path(...) method without an extension.""" checkpoint_tracker = CheckpointTracker( base_dir="base_dir", prefix="prefix", extension=None, index_padding=8, ) path = checkpoint_tracker.get_path(epoch=1, step=55) assert Path(path) == Path("base_dir/prefix-epoch_00000001-step_00000055") def test_checkpoint_tracker_bad_extension(): """Test that CheckpointTracker raises a ValueError if an attempt is made to initialize it with an invalid extension. """ with pytest.raises(ValueError): _ = CheckpointTracker(base_dir="base_dir", prefix="prefix", extension="ckpt") def test_checkpoint_tracker_prune_files(): """Test the CheckpointTracker.prune() method with checkpoint files.""" with tempfile.TemporaryDirectory() as dir_name: checkpoint_tracker = CheckpointTracker(base_dir=dir_name, prefix="prefix", extension=".ckpt", max_checkpoints=5) # Create 6 checkpoints. for i in range(6): path = checkpoint_tracker.get_path(epoch=0, step=i) with open(path, "w") as f: f.write("hi") # Prune the 3 checkpoints with the lowest step counts. num_pruned = checkpoint_tracker.prune(2) assert num_pruned == 3 # Verify that the correct checkpoints were pruned. 
assert all([not os.path.exists(checkpoint_tracker.get_path(epoch=0, step=i)) for i in range(3)]) assert all([os.path.exists(checkpoint_tracker.get_path(epoch=0, step=i)) for i in range(3, 6)]) def test_checkpoint_tracker_prune_directories(): """Test the CheckpointTracker.prune() method with checkpoint directories.""" with tempfile.TemporaryDirectory() as dir_name: checkpoint_tracker = CheckpointTracker(base_dir=dir_name, prefix="prefix", extension=None, max_checkpoints=5) # Create 6 checkpoints. for i in range(6): path = checkpoint_tracker.get_path(epoch=0, step=i) # Create checkpoint directory and add file to it. os.makedirs(path) with open(os.path.join(path, "tmp.txt"), "w") as f: f.write("hi") # Prune the 3 checkpoints with lowest indices. num_pruned = checkpoint_tracker.prune(2) assert num_pruned == 3 # Verify that the correct checkpoints were pruned. assert all([not os.path.exists(checkpoint_tracker.get_path(epoch=0, step=i)) for i in range(3)]) assert all([os.path.exists(checkpoint_tracker.get_path(epoch=0, step=i)) for i in range(3, 6)]) def test_checkpoint_tracker_prune_no_max(): """Test that CheckpointTracker.prune() is a no-op when max_checkpoints is None.""" with tempfile.TemporaryDirectory() as dir_name: checkpoint_tracker = CheckpointTracker( base_dir=dir_name, prefix="prefix", extension=".ckpt", max_checkpoints=None ) # Create 6 checkpoints. for i in range(6): path = checkpoint_tracker.get_path(epoch=0, step=i) with open(path, "w") as f: f.write("hi") # Call prune, which should have no effect. num_pruned = checkpoint_tracker.prune(2) assert num_pruned == 0 # Verify that no checkpoints were deleted. 
assert all([os.path.exists(checkpoint_tracker.get_path(epoch=0, step=i)) for i in range(6)]) ================================================ FILE: tests/invoke_training/_shared/checkpoints/test_serialization.py ================================================ import os import tempfile import pytest import torch from invoke_training._shared.checkpoints.serialization import ( load_state_dict, save_state_dict, ) @pytest.mark.parametrize("file_name", ["state.ckpt", "state.pt", "state.safetensors"]) def test_state_dict_save_and_load_roundtrip(file_name): with tempfile.TemporaryDirectory() as dir_name: file_path = os.path.join(dir_name, file_name) in_state_dict = {"a": torch.Tensor([1.0, 2.0])} # Perform save-load roundtrip. save_state_dict(in_state_dict, file_path) out_state_dict = load_state_dict(file_path) assert len(in_state_dict) == len(out_state_dict) for key in in_state_dict: assert torch.equal(in_state_dict[key], out_state_dict[key]) def test_save_state_dict_bad_extension(): """Test that save_state_dict(...) raises a ValueError if it receives an unsupported file extension.""" with pytest.raises(ValueError): save_state_dict({}, "state.txt") def test_load_state_dict_bad_extension(): """Test that load_state_dict(...) 
raises a ValueError if it receives an unsupported file extension.""" with pytest.raises(ValueError): load_state_dict("state.txt") ================================================ FILE: tests/invoke_training/_shared/data/__init__.py ================================================ ================================================ FILE: tests/invoke_training/_shared/data/data_loaders/__init__.py ================================================ ================================================ FILE: tests/invoke_training/_shared/data/data_loaders/test_dreambooth_sd_dataloader.py ================================================ import torch from invoke_training._shared.data.data_loaders.dreambooth_sd_dataloader import ( build_dreambooth_sd_dataloader, ) from invoke_training.config.data.data_loader_config import AspectRatioBucketConfig, DreamboothSDDataLoaderConfig from invoke_training.config.data.dataset_config import ImageDirDatasetConfig from ..dataset_fixtures import image_dir # noqa: F401 def test_build_dreambooth_sd_dataloader(image_dir): # noqa: F811 """Smoke test of build_dreambooth_sd_dataloader(...).""" config = DreamboothSDDataLoaderConfig( instance_caption="test instance prompt", instance_dataset=ImageDirDatasetConfig(dataset_dir=str(image_dir)), class_caption="test class prompt", # For testing, we just use the same directory for the instance and class datasets. 
class_dataset=ImageDirDatasetConfig(dataset_dir=str(image_dir)), ) data_loader = build_dreambooth_sd_dataloader(config=config, batch_size=2) assert len(data_loader) == 5 # (5 class images + 5 instance images) / batch size 2 example = next(iter(data_loader)) assert set(example.keys()) == {"image", "id", "caption", "original_size_hw", "crop_top_left_yx", "loss_weight"} image = example["image"] assert image.shape == (2, 3, 512, 512) assert image.dtype == torch.float32 original_size_hw = example["original_size_hw"] assert len(original_size_hw) == 2 assert len(original_size_hw[0]) == 2 crop_top_left_yx = example["crop_top_left_yx"] assert len(crop_top_left_yx) == 2 assert len(crop_top_left_yx[0]) == 2 caption = example["caption"] assert caption == ["test instance prompt", "test class prompt"] loss_weight = example["loss_weight"] assert loss_weight.shape == (2,) assert loss_weight.dtype == torch.float32 def test_build_dreambooth_sd_dataloader_no_class_dataset(image_dir): # noqa: F811 """Smoke test of build_dreambooth_sd_dataloader(...) 
without a class dataset.""" config = DreamboothSDDataLoaderConfig( instance_caption="test instance prompt", instance_dataset=ImageDirDatasetConfig(dataset_dir=str(image_dir)), ) data_loader = build_dreambooth_sd_dataloader(config=config, batch_size=2) assert len(data_loader) == 3 # 5 instance images, batch size 2 example = next(iter(data_loader)) assert set(example.keys()) == {"image", "id", "caption", "original_size_hw", "crop_top_left_yx", "loss_weight"} image = example["image"] assert image.shape == (2, 3, 512, 512) assert image.dtype == torch.float32 original_size_hw = example["original_size_hw"] assert len(original_size_hw) == 2 assert len(original_size_hw[0]) == 2 crop_top_left_yx = example["crop_top_left_yx"] assert len(crop_top_left_yx) == 2 assert len(crop_top_left_yx[0]) == 2 caption = example["caption"] assert caption == ["test instance prompt", "test instance prompt"] loss_weight = example["loss_weight"] assert loss_weight.shape == (2,) assert loss_weight.dtype == torch.float32 def test_build_dreambooth_sd_dataloader_with_bucketing(image_dir): # noqa: F811 """Smoke test of build_dreambooth_sd_dataloader(...).""" config = DreamboothSDDataLoaderConfig( instance_caption="test instance prompt", instance_dataset=ImageDirDatasetConfig(dataset_dir=str(image_dir)), class_caption="test class prompt", # For testing, we just use the same directory for the instance and class datasets. 
class_dataset=ImageDirDatasetConfig(dataset_dir=str(image_dir)), aspect_ratio_buckets=AspectRatioBucketConfig( target_resolution=256, start_dim=128, end_dim=512, divisible_by=64 ), ) data_loader = build_dreambooth_sd_dataloader(config=config, batch_size=2, shuffle=False, sequential_batching=True) assert len(data_loader) == 6 # 5 class images -> 3 batches + 5 instance images -> 3 batches example = next(iter(data_loader)) assert set(example.keys()) == {"image", "id", "caption", "original_size_hw", "crop_top_left_yx", "loss_weight"} image = example["image"] assert image.shape == (2, 3, 256, 256) assert image.dtype == torch.float32 original_size_hw = example["original_size_hw"] assert len(original_size_hw) == 2 assert len(original_size_hw[0]) == 2 crop_top_left_yx = example["crop_top_left_yx"] assert len(crop_top_left_yx) == 2 assert len(crop_top_left_yx[0]) == 2 caption = example["caption"] assert caption == ["test instance prompt", "test instance prompt"] loss_weight = example["loss_weight"] assert loss_weight.shape == (2,) assert loss_weight.dtype == torch.float32 ================================================ FILE: tests/invoke_training/_shared/data/data_loaders/test_image_caption_sd_dataloader.py ================================================ import math import torch from invoke_training._shared.data.data_loaders.image_caption_sd_dataloader import build_image_caption_sd_dataloader from invoke_training.config.data.data_loader_config import ImageCaptionSDDataLoaderConfig from invoke_training.config.data.dataset_config import ImageCaptionJsonlDatasetConfig from ..dataset_fixtures import image_caption_jsonl # noqa: F401 def test_build_image_caption_sd_dataloader(image_caption_jsonl): # noqa: F811 """Smoke test of build_image_caption_sd_dataloader(...).""" config = ImageCaptionSDDataLoaderConfig( dataset=ImageCaptionJsonlDatasetConfig(jsonl_path=str(image_caption_jsonl)), ) data_loader = build_image_caption_sd_dataloader(config, 4) # The dataset has length 5, so 
the data loader should have 2 batches. assert len(data_loader) == math.ceil(5 / 4) example = next(iter(data_loader)) assert set(example.keys()) == {"image", "id", "caption", "original_size_hw", "crop_top_left_yx"} image = example["image"] assert image.shape == (4, 3, 512, 512) assert image.dtype == torch.float32 assert len(example["caption"]) == 4 original_size_hw = example["original_size_hw"] assert len(original_size_hw) == 4 assert len(original_size_hw[0]) == 2 crop_top_left_yx = example["crop_top_left_yx"] assert len(crop_top_left_yx) == 4 assert len(crop_top_left_yx[0]) == 2 def test_build_image_caption_sd_dataloader_with_masks(image_caption_jsonl): # noqa: F811 """Smoke test of build_image_caption_sd_dataloader(...).""" config = ImageCaptionSDDataLoaderConfig( dataset=ImageCaptionJsonlDatasetConfig(jsonl_path=str(image_caption_jsonl)), ) data_loader = build_image_caption_sd_dataloader(config, 4, use_masks=True) # The dataset has length 5, so the data loader should have 2 batches. assert len(data_loader) == math.ceil(5 / 4) example = next(iter(data_loader)) assert set(example.keys()) == {"image", "mask", "id", "caption", "original_size_hw", "crop_top_left_yx"} image = example["image"] assert image.shape == (4, 3, 512, 512) assert image.dtype == torch.float32 mask = example["mask"] assert mask.shape == (4, 1, 512, 512) assert mask.dtype == torch.float32 assert len(example["caption"]) == 4 original_size_hw = example["original_size_hw"] assert len(original_size_hw) == 4 assert len(original_size_hw[0]) == 2 crop_top_left_yx = example["crop_top_left_yx"] assert len(crop_top_left_yx) == 4 assert len(crop_top_left_yx[0]) == 2 ================================================ FILE: tests/invoke_training/_shared/data/data_loaders/test_image_pair_preference_sd_dataloader.py ================================================ import pytest import torch from invoke_training._shared.data.data_loaders.image_pair_preference_sd_dataloader import ( 
build_image_pair_preference_sd_dataloader, ) from invoke_training.pipelines._experimental.sd_dpo_lora.config import ( HFHubImagePairPreferenceDatasetConfig, ImagePairPreferenceSDDataLoaderConfig, ) @pytest.mark.skip( reason="No yuvalkirstain/pickapic_v2 dataset on HF Hub: https://huggingface.co/datasets/yuvalkirstain/pickapic_v2" ) def test_build_image_pair_preference_sd_dataloader(): """Smoke test of build_image_pair_preference_sd_dataloader(...).""" config = ImagePairPreferenceSDDataLoaderConfig(dataset=HFHubImagePairPreferenceDatasetConfig()) data_loader = build_image_pair_preference_sd_dataloader(config, 4) example = next(iter(data_loader)) assert set(example.keys()) == { "id", "image_0", "original_size_hw_0", "crop_top_left_yx_0", "prefer_0", "image_1", "original_size_hw_1", "crop_top_left_yx_1", "prefer_1", "caption", } for image_key in ["image_0", "image_1"]: image = example[image_key] assert image.shape == (4, 3, 512, 512) assert image.dtype == torch.float32 assert len(example["caption"]) == 4 for orig_size_key in ["original_size_hw_0", "original_size_hw_1"]: original_size_hw = example[orig_size_key] assert len(original_size_hw) == 4 assert len(original_size_hw[0]) == 2 for crop_key in ["crop_top_left_yx_0", "crop_top_left_yx_1"]: crop_top_left_yx = example[crop_key] assert len(crop_top_left_yx) == 4 assert len(crop_top_left_yx[0]) == 2 ================================================ FILE: tests/invoke_training/_shared/data/data_loaders/test_textual_inversion_sd_dataloader.py ================================================ import torch from invoke_training._shared.data.data_loaders.textual_inversion_sd_dataloader import ( build_textual_inversion_sd_dataloader, ) from invoke_training.config.data.data_loader_config import TextualInversionSDDataLoaderConfig from invoke_training.config.data.dataset_config import ImageCaptionJsonlDatasetConfig, ImageDirDatasetConfig from ..dataset_fixtures import ( image_caption_jsonl, # noqa: F401 image_dir, # noqa: F401 ) 
def test_build_textual_inversion_sd_dataloader(image_dir): # noqa: F811 """Smoke test of build_textual_inversion_sd_dataloader(...).""" config = TextualInversionSDDataLoaderConfig( dataset=ImageDirDatasetConfig(dataset_dir=str(image_dir)), caption_preset="object" ) data_loader = build_textual_inversion_sd_dataloader( config=config, placeholder_token="placeholder", batch_size=2, ) assert len(data_loader) == 3 # ceil(5 images / batch size 2) example = next(iter(data_loader)) assert set(example.keys()) == {"image", "id", "caption", "original_size_hw", "crop_top_left_yx"} image = example["image"] assert image.shape == (2, 3, 512, 512) assert image.dtype == torch.float32 assert len(example["caption"]) == 2 for caption in example["caption"]: assert "placeholder" in caption original_size_hw = example["original_size_hw"] assert len(original_size_hw) == 2 assert len(original_size_hw[0]) == 2 crop_top_left_yx = example["crop_top_left_yx"] assert len(crop_top_left_yx) == 2 assert len(crop_top_left_yx[0]) == 2 def test_build_textual_inversion_sd_dataloader_keep_original_captions(image_caption_jsonl): # noqa: F811 """Test the keep_original_captions=True option.""" config = TextualInversionSDDataLoaderConfig( dataset=ImageCaptionJsonlDatasetConfig(jsonl_path=str(image_caption_jsonl)), caption_templates=["{}"], keep_original_captions=True, ) data_loader = build_textual_inversion_sd_dataloader( config=config, placeholder_token="placeholder", batch_size=2, ) example = next(iter(data_loader)) assert set(example.keys()) == {"image", "id", "caption", "original_size_hw", "crop_top_left_yx"} assert len(example["caption"]) == 2 for caption in example["caption"]: assert caption.startswith("placeholder ") def test_build_textual_inversion_sd_dataloader_with_masks(image_caption_jsonl): # noqa: F811 """Test the use_masks=True option.""" config = TextualInversionSDDataLoaderConfig( dataset=ImageCaptionJsonlDatasetConfig(jsonl_path=str(image_caption_jsonl)), caption_templates=["{}"], ) 
data_loader = build_textual_inversion_sd_dataloader( config=config, placeholder_token="placeholder", batch_size=2, use_masks=True, ) example = next(iter(data_loader)) assert set(example.keys()) == {"image", "mask", "id", "caption", "original_size_hw", "crop_top_left_yx"} image = example["image"] assert image.shape == (2, 3, 512, 512) assert image.dtype == torch.float32 mask = example["mask"] assert mask.shape == (2, 1, 512, 512) assert mask.dtype == torch.float32 assert len(example["caption"]) == 2 for caption in example["caption"]: assert "placeholder" in caption original_size_hw = example["original_size_hw"] assert len(original_size_hw) == 2 assert len(original_size_hw[0]) == 2 crop_top_left_yx = example["crop_top_left_yx"] assert len(crop_top_left_yx) == 2 assert len(crop_top_left_yx[0]) == 2 ================================================ FILE: tests/invoke_training/_shared/data/dataset_fixtures.py ================================================ import numpy as np import PIL.Image import pytest from invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset from invoke_training._shared.utils.jsonl import save_jsonl @pytest.fixture(scope="session") def image_dir(tmp_path_factory: pytest.TempPathFactory): """A fixture that populates a temp directory with some test images and returns the directory path. Note that the 'session' scope is used to share the same directory across all tests in a session, because it is costly to populate the directory. Refer to https://docs.pytest.org/en/7.4.x/how-to/tmp_path.html#the-tmp-path-factory-fixture for details on the use of tmp_path_factory. 
""" tmp_dir = tmp_path_factory.mktemp("dataset") for i in range(5): rgb_np = np.ones((128, 128, 3), dtype=np.uint8) rgb_pil = PIL.Image.fromarray(rgb_np) rgb_pil.save(tmp_dir / f"{i}.jpg") return tmp_dir @pytest.fixture(scope="session") def image_caption_dir(tmp_path_factory: pytest.TempPathFactory): """A fixture that populates a temp directory with some test images and caption files and returns the directory path. Note that the 'session' scope is used to share the same directory across all tests in a session, because it is costly to populate the directory. Refer to https://docs.pytest.org/en/7.4.x/how-to/tmp_path.html#the-tmp-path-factory-fixture for details on the use of tmp_path_factory. """ tmp_dir = tmp_path_factory.mktemp("dataset") for i in range(5): rgb_np = np.ones((128, 128, 3), dtype=np.uint8) rgb_pil = PIL.Image.fromarray(rgb_np) rgb_pil.save(tmp_dir / f"{i}.jpg") with open(tmp_dir / f"{i}.txt", "w") as f: f.write(f"caption {i}") return tmp_dir @pytest.fixture(scope="session") def image_caption_jsonl(tmp_path_factory: pytest.TempPathFactory): """A fixture that populates a temp directory with a ImageCaptionJsonlDataset and returns the jsonl file path. Note that the 'session' scope is used to share the same directory across all tests in a session, because it is costly to populate the directory. Refer to https://docs.pytest.org/en/7.4.x/how-to/tmp_path.html#the-tmp-path-factory-fixture for details on the use of tmp_path_factory. 
""" tmp_dir = tmp_path_factory.mktemp("dataset") masks_dir = tmp_dir / "masks" masks_dir.mkdir() data = [] for i in range(5): rgb_np = np.ones((128, 128, 3), dtype=np.uint8) rgb_pil = PIL.Image.fromarray(rgb_np) rgb_rel_path = f"{i}.jpg" rgb_pil.save(tmp_dir / rgb_rel_path) mask_np = np.ones((128, 128), dtype=np.uint8) mask_pil = PIL.Image.fromarray(mask_np).convert("L") mask_rel_path = f"masks/{i}.png" mask_pil.save(tmp_dir / mask_rel_path) data.append({"image": str(rgb_rel_path), "mask": str(mask_rel_path), "text": f"caption {i}"}) data_jsonl_path = tmp_dir / "data.jsonl" save_jsonl(data, data_jsonl_path) return data_jsonl_path @pytest.fixture(scope="session") def image_pair_preference_dir(tmp_path_factory: pytest.TempPathFactory): """A fixture that populates a temp directory with a mock dataset intended to be consumed by ImagePairPreferenceDataset, and returns the directory path. Note that the 'session' scope is used to share the same directory across all tests in a session, because it is costly to populate the directory. Refer to https://docs.pytest.org/en/7.4.x/how-to/tmp_path.html#the-tmp-path-factory-fixture for details on the use of tmp_path_factory. """ tmp_dir = tmp_path_factory.mktemp("dataset") prompts = ["mock prompt 1", "mock prompt 2"] metadata = [] for prompt_idx in range(len(prompts)): for set_idx in range(3): set_dir = tmp_dir / f"prompt-{prompt_idx:0>4}" / f"set-{set_idx:0>4}" set_dir.mkdir(parents=True) set_metadata_dict = {"prompt": prompts[prompt_idx]} for image_idx in range(2): rgb_np = np.ones((32, 32, 3), dtype=np.uint8) rgb_pil = PIL.Image.fromarray(rgb_np) image_path = set_dir / f"image-{image_idx}.jpg" rgb_pil.save(image_path) set_metadata_dict[f"image_{image_idx}"] = str(image_path.relative_to(tmp_dir)) set_metadata_dict[f"prefer_{image_idx}"] = image_idx == 0 # Always prefer image 0. 
metadata.append(set_metadata_dict) ImagePairPreferenceDataset.save_metadata(metadata=metadata, dataset_dir=tmp_dir) return tmp_dir ================================================ FILE: tests/invoke_training/_shared/data/datasets/__init__.py ================================================ ================================================ FILE: tests/invoke_training/_shared/data/datasets/test_hf_image_caption_dataset.py ================================================ from pathlib import Path import numpy as np import PIL import pytest from PIL import Image from invoke_training._shared.data.datasets.hf_image_caption_dataset import ( HFImageCaptionDataset, ) from invoke_training._shared.data.utils.resolution import Resolution from invoke_training._shared.utils.jsonl import save_jsonl ################################################ # Tests for HFImageCaptionDataset.from_dir(...) ################################################ def create_hf_imagefolder_dataset(tmp_dir: Path, num_images: int): """Construct a mock Hugging Face imagefolder dataset in a temporary directory. Args: tmp_dir (Path): The temporary directory where the mock dataset will be created. num_images (int): The number of mock images to include in the dataset. """ # Construct mock images and save them to disk. rel_img_paths = [] for i in range(num_images): rgb_np = np.ones((128, 128, 3), dtype=np.uint8) rgb_pil = Image.fromarray(rgb_np) rel_img_path = f"{i}.jpg" rel_img_paths.append(rel_img_path) rgb_pil.save(tmp_dir / rel_img_path) # Construct a mock metadata dict. metadata = [] for rel_img_path in rel_img_paths: metadata.append({"file_name": rel_img_path, "text": f"Caption for {rel_img_path}"}) # Write the metadata.jsonl to disk. 
metadata_path = tmp_dir / "metadata.jsonl" save_jsonl(metadata, metadata_path) @pytest.fixture(scope="session") def hf_imagefolder_dir(tmp_path_factory: pytest.TempPathFactory): """A fixture that prepares a temp directory with a mock Hugging Face imagefolder dataset and returns the directory path. Note that the 'session' scope is used to share the same directory across all tests in a session, because it is costly to populate the directory. Refer to https://docs.pytest.org/en/7.4.x/how-to/tmp_path.html#the-tmp-path-factory-fixture for details on the use of tmp_path_factory. """ tmp_dir = tmp_path_factory.mktemp("dataset") create_hf_imagefolder_dataset(tmp_dir, 5) return tmp_dir @pytest.fixture() def hf_dir_dataset(hf_imagefolder_dir: Path): return HFImageCaptionDataset.from_dir(str(hf_imagefolder_dir)) def test_hf_dir_image_caption_dataset_bad_image_column(hf_imagefolder_dir: Path): """Test that a ValueError is raised if HFImageCaptionDataset is initialized with an `image_column` that does not exist. """ with pytest.raises(ValueError): _ = HFImageCaptionDataset.from_dir(str(hf_imagefolder_dir), image_column="does_not_exist") def test_hf_dir_image_caption_dataset_bad_caption_column(hf_imagefolder_dir: Path): """Test that a ValueError is raised if HFImageCaptionDataset is initialized with a `caption_column` that does not exist. 
""" with pytest.raises(ValueError): _ = HFImageCaptionDataset.from_dir(str(hf_imagefolder_dir), caption_column="does_not_exist") def test_hf_dir_image_caption_dataset_len(hf_dir_dataset: HFImageCaptionDataset): """Test the behaviour of HFImageCaptionDataset.__len__().""" assert len(hf_dir_dataset) == 5 def test_hf_dir_image_caption_dataset_index_error(hf_dir_dataset: HFImageCaptionDataset): """Test that an IndexError is raised if a dataset element is accessed with an index that is out-of-bounds.""" with pytest.raises(IndexError): _ = hf_dir_dataset[1000] def test_hf_dir_image_caption_dataset_getitem(hf_dir_dataset: HFImageCaptionDataset): """Test that HFImageCaptionDataset.__getitem__(...) returns a valid example.""" example = hf_dir_dataset[0] assert set(example.keys()) == {"image", "caption", "id"} assert isinstance(example["image"], PIL.Image.Image) assert example["image"].mode == "RGB" assert isinstance(example["caption"], str) assert example["id"] == 0 def test_hf_dir_image_caption_dataset_get_image_dimensions(hf_dir_dataset: HFImageCaptionDataset): """Test HFImageCaptionDataset.get_image_dimensions().""" image_dims = hf_dir_dataset.get_image_dimensions() assert len(image_dims) == 5 for image_dim in image_dims: assert image_dim == Resolution(128, 128) ################################################ # Tests for HFImageCaptionDataset.from_hub(...) ################################################ @pytest.mark.skip(reason="The lambdalabs/pokemon-blip-captions dataset is no longer available.") @pytest.mark.loads_model def test_hf_hub_image_caption_dataset_bad_image_column(): """Test that a ValueError is raised if HFImageCaptionDataset is initialized with an `image_column` that does not exist. 
""" with pytest.raises(ValueError): _ = HFImageCaptionDataset.from_hub( "lambdalabs/pokemon-blip-captions", hf_load_dataset_kwargs={"revision": "8b762e1dac1b31d60e01ee8f08a9d8a232b59e17"}, image_column="does_not_exist", ) @pytest.mark.skip(reason="The lambdalabs/pokemon-blip-captions dataset is no longer available.") @pytest.mark.loads_model def test_hf_hub_image_caption_dataset_bad_caption_column(): """Test that a ValueError is raised if HFImageCaptionDataset is initialized with a `caption_column` that does not exist. """ with pytest.raises(ValueError): _ = HFImageCaptionDataset.from_hub( "lambdalabs/pokemon-blip-captions", hf_load_dataset_kwargs={"revision": "8b762e1dac1b31d60e01ee8f08a9d8a232b59e17"}, caption_column="does_not_exist", ) @pytest.fixture def hf_hub_dataset(): return HFImageCaptionDataset.from_hub( "lambdalabs/pokemon-blip-captions", hf_load_dataset_kwargs={"revision": "8b762e1dac1b31d60e01ee8f08a9d8a232b59e17"}, ) @pytest.mark.skip(reason="The lambdalabs/pokemon-blip-captions dataset is no longer available.") @pytest.mark.loads_model def test_hf_hub_image_caption_dataset_index_error(hf_hub_dataset: HFImageCaptionDataset): """Test that an IndexError is raised if a dataset element is accessed with an index that is out-of-bounds.""" with pytest.raises(IndexError): _ = hf_hub_dataset[1000] @pytest.mark.skip(reason="The lambdalabs/pokemon-blip-captions dataset is no longer available.") @pytest.mark.loads_model def test_hf_hub_image_caption_dataset_len(hf_hub_dataset: HFImageCaptionDataset): """Test the behaviour of HFImageCaptionDataset.__len__().""" # Expected dataset length was checked manually here: # https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions assert len(hf_hub_dataset) == 833 @pytest.mark.skip(reason="The lambdalabs/pokemon-blip-captions dataset is no longer available.") @pytest.mark.loads_model def test_hf_hub_image_caption_dataset_getitem(hf_hub_dataset: HFImageCaptionDataset): """Test that 
HFImageCaptionDataset.__getitem__(...) returns a valid example.""" example = hf_hub_dataset[0] assert set(example.keys()) == {"image", "caption", "id"} assert isinstance(example["image"], PIL.Image.Image) assert example["image"].mode == "RGB" assert isinstance(example["caption"], str) assert example["id"] == 0 @pytest.mark.skip(reason="The lambdalabs/pokemon-blip-captions dataset is no longer available.") @pytest.mark.loads_model def test_hf_hub_image_caption_dataset_get_image_dimensions(hf_hub_dataset: HFImageCaptionDataset): """Test HFImageCaptionDataset.get_image_dimensions().""" image_dims = hf_hub_dataset.get_image_dimensions() # This is just a smoke test. We don't actually check that the dimensions are correct. assert len(image_dims) == 833 ================================================ FILE: tests/invoke_training/_shared/data/datasets/test_hf_image_pair_preference_dataset.py ================================================ import pytest from datasets import VerificationMode from PIL.Image import Image from invoke_training._shared.data.datasets.hf_image_pair_preference_dataset import HFImagePairPreferenceDataset @pytest.mark.loads_model def test_hf_hub_image_caption_dataset_getitem(): """Test that HFImagePairPreferenceDataset.__getitem__(...) returns a valid example.""" # HACK(ryand): This funky configuration is done so that we just download a small slice of the very large # 'yuvalkirstain/pickapic_v2' dataset. dataset = HFImagePairPreferenceDataset.from_hub( "yuvalkirstain/pickapic_v2", split="validation_unique", hf_load_dataset_kwargs={ "data_files": { "validation_unique": "data/validation_unique-00000-of-00001-33ead111845fc9c4.parquet", }, # Disable checks so that it doesn't complain that I haven't downloaded the other splits. 
"verification_mode": VerificationMode.NO_CHECKS, }, ) example = dataset[0] assert set(example.keys()) == {"id", "image_0", "image_1", "prefer_0", "prefer_1", "caption"} assert example["id"] == 0 assert isinstance(example["image_0"], Image) assert example["image_0"].mode == "RGB" assert isinstance(example["image_1"], Image) assert example["image_1"].mode == "RGB" assert isinstance(example["prefer_0"], bool) assert isinstance(example["prefer_1"], bool) # The following is not always true, but is usually true. assert example["prefer_0"] != example["prefer_1"] assert isinstance(example["caption"], str) @pytest.mark.loads_model def test_hf_hub_image_caption_dataset_len(): """Test that HFImagePairPreferenceDataset.__len__(...) returns the correct value.""" # HACK(ryand): This funky configuration is done so that we just download a small slice of the very large # 'yuvalkirstain/pickapic_v2' dataset. dataset = HFImagePairPreferenceDataset.from_hub( "yuvalkirstain/pickapic_v2", skip_no_preference=False, split="validation_unique", hf_load_dataset_kwargs={ "data_files": { "validation_unique": "data/validation_unique-00000-of-00001-33ead111845fc9c4.parquet", }, # Disable checks so that it doesn't complain that I haven't downloaded the other splits. "verification_mode": VerificationMode.NO_CHECKS, }, ) assert len(dataset) == 500 @pytest.mark.loads_model def test_hf_hub_image_caption_dataset_skip_no_preference_len(): """Test the HFImagePairPreferenceDataset skip_no_preference parameter.""" # HACK(ryand): This funky configuration is done so that we just download a small slice of the very large # 'yuvalkirstain/pickapic_v2' dataset. dataset = HFImagePairPreferenceDataset.from_hub( "yuvalkirstain/pickapic_v2", skip_no_preference=True, split="validation_unique", hf_load_dataset_kwargs={ "data_files": { "validation_unique": "data/validation_unique-00000-of-00001-33ead111845fc9c4.parquet", }, # Disable checks so that it doesn't complain that I haven't downloaded the other splits. 
"verification_mode": VerificationMode.NO_CHECKS, }, ) assert len(dataset) == 429 ================================================ FILE: tests/invoke_training/_shared/data/datasets/test_image_caption_dir_dataset.py ================================================ from pathlib import Path import PIL.Image import pytest from invoke_training._shared.data.datasets.image_caption_dir_dataset import ImageCaptionDirDataset from ..dataset_fixtures import image_caption_dir # noqa: F401 def test_image_caption_dir_dataset_len(image_caption_dir): # noqa: F811 dataset = ImageCaptionDirDataset(str(image_caption_dir)) assert len(dataset) == 5 def test_image_caption_dir_dataset_getitem(image_caption_dir): # noqa: F811 dataset = ImageCaptionDirDataset(str(image_caption_dir)) example = dataset[0] assert set(example.keys()) == {"image", "id", "caption"} assert isinstance(example["image"], PIL.Image.Image) assert example["image"].mode == "RGB" assert example["id"] == "0" assert example["caption"] == "caption 0" def test_image_caption_dir_dataset_keep_in_memory(image_caption_dir): # noqa: F811 dataset = ImageCaptionDirDataset(str(image_caption_dir), keep_in_memory=True) example = dataset[0] assert set(example.keys()) == {"image", "id", "caption"} assert isinstance(example["image"], PIL.Image.Image) assert example["image"].mode == "RGB" assert example["id"] == "0" assert example["caption"] == "caption 0" def test_image_caption_dir_dataset_get_image_dimensions(image_caption_dir): # noqa: F811 dataset = ImageCaptionDirDataset(str(image_caption_dir)) image_dims = dataset.get_image_dimensions() assert len(image_dims) == len(dataset) def test_image_caption_dir_dataset_missing_caption_file(tmp_path: Path): # noqa: F811 # Create a directory with an image but no caption file. 
with open(tmp_path / "0.jpg", "w"): pass with pytest.raises(Exception, match=r"The following expected caption files are missing: \['.*0.txt'\]"): ImageCaptionDirDataset(str(tmp_path)) ================================================ FILE: tests/invoke_training/_shared/data/datasets/test_image_caption_jsonl_dataset.py ================================================ import shutil from pathlib import Path import PIL.Image from invoke_training._shared.data.datasets.image_caption_jsonl_dataset import ImageCaptionJsonlDataset from invoke_training._shared.utils.jsonl import load_jsonl from ..dataset_fixtures import image_caption_jsonl # noqa: F401 def test_image_caption_jsonl_dataset_len(image_caption_jsonl): # noqa: F811 dataset = ImageCaptionJsonlDataset(str(image_caption_jsonl)) assert len(dataset) == 5 def test_image_caption_jsonl_dataset_getitem(image_caption_jsonl): # noqa: F811 dataset = ImageCaptionJsonlDataset(str(image_caption_jsonl)) example = dataset[0] assert set(example.keys()) == {"image", "id", "caption", "mask"} assert isinstance(example["image"], PIL.Image.Image) assert example["image"].mode == "RGB" assert example["id"] == "0" assert example["caption"] == "caption 0" assert isinstance(example["mask"], PIL.Image.Image) assert example["mask"].mode == "L" def test_image_caption_jsonl_dataset_keep_in_memory(image_caption_jsonl): # noqa: F811 dataset = ImageCaptionJsonlDataset(str(image_caption_jsonl), keep_in_memory=True) example = dataset[0] assert set(example.keys()) == {"image", "id", "caption", "mask"} assert isinstance(example["image"], PIL.Image.Image) assert example["image"].mode == "RGB" assert example["id"] == "0" assert example["caption"] == "caption 0" assert isinstance(example["mask"], PIL.Image.Image) assert example["mask"].mode == "L" # Confirm that accessing the same example again returns a shallow copy of the original example. 
# In other words, modifying the returned dict should not modify the cached example, but the same image should be # returned. same_example = dataset[0] assert same_example is not example assert same_example["image"] is example["image"] def test_image_caption_jsonl_dataset_get_image_dimensions(image_caption_jsonl): # noqa: F811 dataset = ImageCaptionJsonlDataset(str(image_caption_jsonl)) image_dims = dataset.get_image_dimensions() assert len(image_dims) == len(dataset) def test_image_caption_jsonl_dataset_save_jsonl(image_caption_jsonl, tmp_path: Path): # noqa: F811 # Create a copy of the image_caption_jsonl file to avoid modifying the original file. image_caption_jsonl_copy = tmp_path / "test.jsonl" shutil.copy(image_caption_jsonl, image_caption_jsonl_copy) # Load the dataset from the copied jsonl file. dataset = ImageCaptionJsonlDataset(str(image_caption_jsonl)) # Save the dataset to a new jsonl file. dataset.save_jsonl() # Verify that the roundtrip was successful. assert image_caption_jsonl != image_caption_jsonl_copy original_jsonl = load_jsonl(image_caption_jsonl) roundtrip_jsonl = load_jsonl(image_caption_jsonl_copy) assert original_jsonl == roundtrip_jsonl ================================================ FILE: tests/invoke_training/_shared/data/datasets/test_image_dir_dataset.py ================================================ import PIL.Image from invoke_training._shared.data.datasets.image_dir_dataset import ImageDirDataset from ..dataset_fixtures import image_dir # noqa: F401 def test_image_dir_dataset_len(image_dir): # noqa: F811 dataset = ImageDirDataset(str(image_dir)) assert len(dataset) == 5 def test_image_dir_dataset_getitem(image_dir): # noqa: F811 dataset = ImageDirDataset(str(image_dir)) example = dataset[0] assert set(example.keys()) == {"image", "id"} assert isinstance(example["image"], PIL.Image.Image) assert example["image"].mode == "RGB" assert example["id"] == "0" def test_image_dir_dataset_keep_in_memory(image_dir): # noqa: F811 dataset = 
ImageDirDataset(str(image_dir), keep_in_memory=True) example = dataset[0] assert set(example.keys()) == {"image", "id"} assert isinstance(example["image"], PIL.Image.Image) assert example["image"].mode == "RGB" assert example["id"] == "0" # Confirm that accessing the same example again returns a shallow copy of the original example. # In other words, modifying the returned dict should not modify the cached example, but the same image should be # returned. same_example = dataset[0] assert same_example is not example assert same_example["image"] is example["image"] def test_image_dir_dataset_get_image_dimensions(image_dir): # noqa: F811 dataset = ImageDirDataset(str(image_dir)) image_dims = dataset.get_image_dimensions() assert len(image_dims) == len(dataset) ================================================ FILE: tests/invoke_training/_shared/data/datasets/test_image_pair_preference_dataset.py ================================================ import PIL.Image from invoke_training._shared.data.datasets.image_pair_preference_dataset import ImagePairPreferenceDataset from ..dataset_fixtures import image_pair_preference_dir # noqa: F401 def test_image_dir_dataset_len(image_pair_preference_dir): # noqa: F811 dataset = ImagePairPreferenceDataset(str(image_pair_preference_dir)) assert len(dataset) == 6 def test_image_dir_dataset_getitem(image_pair_preference_dir): # noqa: F811 dataset = ImagePairPreferenceDataset(str(image_pair_preference_dir)) example = dataset[0] assert set(example.keys()) == {"id", "image_0", "image_1", "caption", "prefer_0", "prefer_1"} assert example["id"] == "0" assert isinstance(example["image_0"], PIL.Image.Image) assert example["image_0"].mode == "RGB" assert isinstance(example["image_1"], PIL.Image.Image) assert example["image_1"].mode == "RGB" assert example["prefer_0"] assert not example["prefer_1"] ================================================ FILE: tests/invoke_training/_shared/data/datasets/test_transform_dataset.py 
================================================ import unittest.mock from invoke_training._shared.data.datasets.transform_dataset import TransformDataset def test_transform_dataset_len(): """Test the TransformDataset len() function.""" mock_dataset = unittest.mock.MagicMock() mock_dataset.__len__.return_value = 5 dataset = TransformDataset(mock_dataset, []) assert len(dataset) == 5 def test_transform_dataset_getitem(): """Test the TransformDataset __getitem__() function.""" field1 = 1 field2 = "2" base_example = {"field1": field1} mock_dataset = unittest.mock.MagicMock() mock_dataset.__getitem__.return_value = base_example def mock_transform(example): example["field2"] = field2 return example dataset = TransformDataset(mock_dataset, [mock_transform]) out_example = dataset[0] assert out_example["field1"] == field1 assert out_example["field2"] == field2 ================================================ FILE: tests/invoke_training/_shared/data/samplers/__init__.py ================================================ ================================================ FILE: tests/invoke_training/_shared/data/samplers/test_aspect_ratio_bucket_batch_sampler.py ================================================ from invoke_training._shared.data.samplers.aspect_ratio_bucket_batch_sampler import ( AspectRatioBucketBatchSampler, ) from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager from invoke_training._shared.data.utils.resolution import Resolution def assert_shuffled_samples_match(samples_1, samples_2): """Utility function to assert that two batch sampler outputs are equivalent aside from having been shuffled.""" # Same number of batches. assert len(samples_1) == len(samples_2) # Same total number of examples. assert sum([len(b) for b in samples_1]) == sum([len(b) for b in samples_2]) # Same set of examples. 
assert {x for batch in samples_1 for x in batch} == {x for batch in samples_2 for x in batch} def test_aspect_ratio_bucket_batch_sampler(): """Basic test of AspectRatioBucketBatchSampler.""" sampler = AspectRatioBucketBatchSampler( buckets={Resolution(256, 768): [1, 3, 5], Resolution(512, 512): [4], Resolution(768, 256): [0, 2]}, batch_size=2, shuffle=False, seed=None, ) assert list(sampler) == [[1, 3], [5], [4], [0, 2]] def test_aspect_ratio_bucket_batch_sampler_len(): """Basic test of AspectRatioBucketBatchSampler len(...) function.""" sampler = AspectRatioBucketBatchSampler( buckets={Resolution(256, 768): [1, 3, 5], Resolution(512, 512): [4], Resolution(768, 256): [0, 2]}, batch_size=2, shuffle=False, seed=None, ) assert len(sampler) == len(list(sampler)) def test_aspect_ratio_bucket_batch_sampler_from_image_sizes(): """Test AspectRatioBucketBatchSampler when initialized with AspectRatioBucketBatchSampler.from_image_size(...).""" # Configure bucket_manager to have the following aspect ratio buckets: # (256, 1024), (256, 768), (512, 512), (768, 256), (1024, 768) bucket_manager = AspectRatioBucketManager.from_constraints( target_resolution=512, start_dim=256, end_dim=768, divisible_by=256 ) image_sizes = [ Resolution(256, 768), # Bucket 1 (256, 768) Resolution(512, 512), # Bucket 2 (512, 512) Resolution(768, 256), # Bucket 3 (768, 256) Resolution(264, 768), # Bucket 1 (256, 768) Resolution(272, 768), # Bucket 1 (256, 768) Resolution(768, 264), # Bucket 3 (768, 256) ] sampler = AspectRatioBucketBatchSampler.from_image_sizes( bucket_manager=bucket_manager, image_sizes=image_sizes, batch_size=2, shuffle=False ) assert list(sampler) == [[0, 3], [4], [1], [2, 5]] def test_aspect_ratio_bucket_batch_sampler_shuffle(): """Test AspectRatioBucketBatchSampler shuffle behavior.""" buckets = {Resolution(256, 512): [1, 3, 5, 6, 7], Resolution(512, 512): [4], Resolution(512, 256): [0, 2]} batch_size = 2 unshuffled_sampler = AspectRatioBucketBatchSampler(buckets=buckets, 
batch_size=batch_size, shuffle=False, seed=None) shuffled_sampler = AspectRatioBucketBatchSampler(buckets=buckets, batch_size=batch_size, shuffle=True, seed=None) unshuffled_samples = list(unshuffled_sampler) shuffled_samples = list(shuffled_sampler) assert_shuffled_samples_match(shuffled_samples, unshuffled_samples) # Not equal, because one is shuffled. assert shuffled_samples != unshuffled_samples def test_aspect_ratio_bucket_batch_sampler_seed(): """Test AspectRatioBucketBatchSampler seed behavior.""" buckets = {Resolution(256, 512): [1, 3, 5, 6, 7], Resolution(512, 512): [4], Resolution(512, 256): [0, 2]} batch_size = 2 base_sampler = AspectRatioBucketBatchSampler(buckets=buckets, batch_size=batch_size, shuffle=True, seed=1) same_seed_sampler = AspectRatioBucketBatchSampler(buckets=buckets, batch_size=batch_size, shuffle=True, seed=1) diff_seed_sampler = AspectRatioBucketBatchSampler(buckets=buckets, batch_size=batch_size, shuffle=True, seed=2) base_samples = list(base_sampler) same_seed_samples = list(same_seed_sampler) diff_seed_samples = list(diff_seed_sampler) # Samples generated with the same seed should match exactly. assert base_samples == same_seed_samples # Samples generated with different seeds should match, except for the example ordering. 
assert_shuffled_samples_match(base_samples, diff_seed_samples) assert base_samples != diff_seed_samples ================================================ FILE: tests/invoke_training/_shared/data/samplers/test_batch_offset_sampler.py ================================================ from torch.utils.data.sampler import BatchSampler, SequentialSampler from invoke_training._shared.data.samplers.batch_offset_sampler import BatchOffsetSampler def test_batch_offset_sampler(): """Test that the BatchOffsetSampler yields the correct sequence of values.""" sequential_sampler = SequentialSampler([0] * 5) batch_sampler = BatchSampler(sequential_sampler, batch_size=2, drop_last=False) batch_offset_sampler = BatchOffsetSampler(sampler=batch_sampler, offset=10) assert list(batch_offset_sampler) == [[10, 11], [12, 13], [14]] # Assert that it can be iterated multiple times. assert list(batch_offset_sampler) == [[10, 11], [12, 13], [14]] def test_batch_offset_sampler_len(): """Test the BatchOffsetSampler len() function.""" sequential_sampler = SequentialSampler([0] * 5) batch_sampler = BatchSampler(sequential_sampler, batch_size=2, drop_last=False) batch_offset_sampler = BatchOffsetSampler(sampler=batch_sampler, offset=10) assert len(batch_offset_sampler) == 3 ================================================ FILE: tests/invoke_training/_shared/data/samplers/test_concat_sampler.py ================================================ from invoke_training._shared.data.samplers.concat_sampler import ConcatSampler def test_concat_sampler(): """Test that the ConcatSampler yields the correct sequence.""" sampler_1 = [0, 1, 2, 3] sampler_2 = [4, 5, 6] sampler_3 = [7, 8, 9, 10, 11, 12] sampler = ConcatSampler([sampler_1, sampler_2, sampler_3]) samples = list(sampler) assert samples == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] def test_concat_sampler_batches(): """Test that the ConcatSampler yields the correct sequence with batch samplers.""" sampler_1 = [[0, 1, 2], [3, 4, 5], [6]] sampler_2 = 
[[7, 8], [9]] sampler_3 = [[10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21]] sampler = ConcatSampler([sampler_1, sampler_2, sampler_3]) samples = list(sampler) assert samples == [[0, 1, 2], [3, 4, 5], [6], [7, 8], [9], [10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21]] def test_concat_sampler_len(): """Test the ConcatSampler len() function.""" sampler_1 = [0, 1, 2, 3] sampler_2 = [4, 5, 6] sampler_3 = [7, 8, 9, 10, 11, 12] sampler = ConcatSampler([sampler_1, sampler_2, sampler_3]) assert len(sampler) == 13 ================================================ FILE: tests/invoke_training/_shared/data/samplers/test_interleaved_sampler.py ================================================ from invoke_training._shared.data.samplers.interleaved_sampler import InterleavedSampler def test_interleaved_sampler(): """Test that the InterleavedSampler yields the correct sequence.""" sampler_1 = [0, 1, 2, 3] sampler_2 = [4, 5, 6] sampler_3 = [7, 8, 9, 10, 11, 12] sampler = InterleavedSampler([sampler_1, sampler_2, sampler_3]) samples = list(sampler) assert samples == [0, 4, 7, 1, 5, 8, 2, 6, 9] def test_interleaved_sampler_batches(): """Test that the InterleavedSampler yields the correct sequence with batch samplers.""" sampler_1 = [[0, 1, 2], [3, 4, 5], [6]] sampler_2 = [[7, 8], [9]] sampler_3 = [[10, 11, 12], [13, 14, 15], [16, 17, 18], [19, 20, 21]] sampler = InterleavedSampler([sampler_1, sampler_2, sampler_3]) samples = list(sampler) assert samples == [[0, 1, 2], [7, 8], [10, 11, 12], [3, 4, 5], [9], [13, 14, 15]] def test_interleaved_sampler_len(): """Test the InterleavedSampler len() function.""" sampler_1 = [0, 1, 2, 3] sampler_2 = [4, 5] sampler_3 = [7, 8, 9, 10, 11, 12] sampler = InterleavedSampler([sampler_1, sampler_2, sampler_3]) assert len(sampler) == 2 * 3 ================================================ FILE: tests/invoke_training/_shared/data/samplers/test_offset_sampler.py ================================================ from torch.utils.data.sampler 
import SequentialSampler from invoke_training._shared.data.samplers.offset_sampler import OffsetSampler def test_offset_sampler(): """Test that the OffsetSampler yields the correct sequence of values.""" sequential_sampler = SequentialSampler([0] * 5) offset_sampler = OffsetSampler(sampler=sequential_sampler, offset=10) assert list(offset_sampler) == list(range(10, 15)) # Assert that it can be iterated multiple times. assert list(offset_sampler) == list(range(10, 15)) def test_offset_sampler_len(): """Test the OffsetSampler len() function.""" sequential_sampler = SequentialSampler([0] * 5) offset_sampler = OffsetSampler(sampler=sequential_sampler, offset=10) assert len(offset_sampler) == 5 ================================================ FILE: tests/invoke_training/_shared/data/transforms/__init__.py ================================================ ================================================ FILE: tests/invoke_training/_shared/data/transforms/test_caption_prefix_transform.py ================================================ from invoke_training._shared.data.transforms.caption_prefix_transform import CaptionPrefixTransform def test_caption_prefix_transform(): tf = CaptionPrefixTransform(caption_field_name="caption", prefix="prefix ") in_example = {"caption": "original caption", "other": 2} out_example = tf(in_example) assert out_example == {"caption": "prefix original caption", "other": 2} ================================================ FILE: tests/invoke_training/_shared/data/transforms/test_concat_fields_transform.py ================================================ from invoke_training._shared.data.transforms.concat_fields_transform import ConcatFieldsTransform def test_caption_prefix_transform(): tf = ConcatFieldsTransform(src_field_names=["caption", "caption_2"], dst_field_name="caption", separator=", ") in_example = {"caption": "original caption", "caption_2": "another caption", "other": 2} out_example = tf(in_example) assert out_example == {"caption": 
"original caption, another caption", "caption_2": "another caption", "other": 2} ================================================ FILE: tests/invoke_training/_shared/data/transforms/test_constant_field_transform.py ================================================ from invoke_training._shared.data.transforms.constant_field_transform import ConstantFieldTransform def test_constant_field_transform(): tf = ConstantFieldTransform("test_field", 1) in_example = {"existing": 2} out_example = tf(in_example) assert out_example == {"existing": 2, "test_field": 1} ================================================ FILE: tests/invoke_training/_shared/data/transforms/test_drop_field_transform.py ================================================ from invoke_training._shared.data.transforms.drop_field_transform import DropFieldTransform def test_drop_field_transform(): tf = DropFieldTransform("drop") in_example = {"keep": 1, "drop": 2} out_example = tf(in_example) assert out_example == {"keep": 1} ================================================ FILE: tests/invoke_training/_shared/data/transforms/test_load_cache_transform.py ================================================ import unittest.mock import torch from invoke_training._shared.data.transforms.load_cache_transform import LoadCacheTransform def test_load_cache_transform(): cached_tensor = torch.Tensor([1.0, 2.0, 3.0]) mock_cache = unittest.mock.MagicMock() mock_cache.load.return_value = {"cached_tensor": cached_tensor} tf = LoadCacheTransform( cache=mock_cache, cache_key_field="cache_key", cache_field_to_output_field={"cached_tensor": "output"} ) in_example = {"cache_key": 1} out_example = tf(in_example) mock_cache.load.assert_called_once_with(1) assert out_example["output"] is cached_tensor ================================================ FILE: tests/invoke_training/_shared/data/transforms/test_sd_image_transform.py ================================================ import unittest.mock import numpy as np import pytest import 
torch from PIL import Image from invoke_training._shared.data.transforms.sd_image_transform import SDImageTransform from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager from invoke_training._shared.data.utils.resolution import Resolution def denormalize_image(img: np.ndarray) -> np.ndarray: """Convert a normalized CxHxW image in range [-1.0, 1.0] to a HxWxC image in the range [0, 255]. Args: img (np.ndarray): Image to denormalize. Returns: np.ndarray: Result image. """ # Convert back to range [0, 1.0]. img = img * 0.5 + 0.5 # Convert back to range [0, 255]. img *= 255 # Move channel axis from first dimension to last dimension. img = np.moveaxis(img, 0, -1) return img def denormalize_mask(mask: np.ndarray) -> np.ndarray: """Convert a normalized CxHxW mask in range [0.0, 1.0] to a HxW mask in the range [0, 255].""" # Convert back to range [0, 255]. mask *= 255 # Squeeze the channel dimension. mask = mask.squeeze(0) return mask def test_sd_image_transform_resolution(): """Test that SDImageTransform resizes and crops to the target resolution, and correctly sets original_size_hw.""" in_image_np = np.ones((256, 128, 3), dtype=np.uint8) in_image_pil = Image.fromarray(in_image_np) in_mask_np = np.ones((256, 128), dtype=np.uint8) in_mask_pil = Image.fromarray(in_mask_np) resolution = Resolution(768, 512) tf = SDImageTransform( image_field_names=["image", "mask"], fields_to_normalize_to_range_minus_one_to_one=["image"], resolution=resolution, ) out_example = tf({"image": in_image_pil, "mask": in_mask_pil}) out_image = out_example["image"] assert isinstance(out_image, torch.Tensor) assert out_image.shape == (3, resolution.height, resolution.width) out_mask = out_example["mask"] assert isinstance(out_mask, torch.Tensor) assert out_mask.shape == (1, resolution.height, resolution.width) original_size_hw = out_example["original_size_hw"] assert original_size_hw == (256, 128) def test_sd_image_transform_without_mask(): """Test that 
SDImageTransform works correctly when no mask is provided.""" in_image_np = np.ones((256, 128, 3), dtype=np.uint8) in_image_pil = Image.fromarray(in_image_np) resolution = Resolution(768, 512) tf = SDImageTransform( image_field_names=["image"], fields_to_normalize_to_range_minus_one_to_one=["image"], resolution=resolution, ) # No mask is provided. out_example = tf({"image": in_image_pil}) out_image = out_example["image"] assert isinstance(out_image, torch.Tensor) assert out_image.shape == (3, resolution.height, resolution.width) original_size_hw = out_example["original_size_hw"] assert original_size_hw == (256, 128) def test_sd_image_transform_range(): """Test that SDImageTransform normalizes the image to the range [-1.0, 1.0], and the mask to the range [0.0, 1.0]. """ resolution = 128 in_image_np = np.zeros((resolution, resolution, 3), dtype=np.uint8) in_image_np[0, 0, :] = 255 # Image contains one pixel with value 255, and the rest are zeros. in_image_pil = Image.fromarray(in_image_np) in_mask_np = np.zeros((resolution, resolution), dtype=np.uint8) in_mask_np[0, 0] = 255 # Mask contains one pixel with value 255, and the rest are zeros. in_mask_pil = Image.fromarray(in_mask_np) tf = SDImageTransform( image_field_names=["image", "mask"], fields_to_normalize_to_range_minus_one_to_one=["image"], resolution=resolution, ) out_example = tf({"image": in_image_pil, "mask": in_mask_pil}) out_image = out_example["image"] out_np = np.array(out_image) assert np.allclose(out_np[:, 0, 0], 1.0) assert np.allclose(out_np[:, 1:, 1:], -1.0) out_mask = out_example["mask"] out_np = np.array(out_mask) assert np.allclose(out_np[0, 0, 0], 1.0) assert np.allclose(out_np[0, 1:, 1:], 0.0) def test_sd_image_transform_center_crop(): """Test SDImageTransform center cropping.""" # Input image is 9 x 5. 
in_image_np = np.arange(9 * 5 * 3, dtype=np.uint8).reshape((9, 5, 3)) in_image_pil = Image.fromarray(np.copy(in_image_np)) mask_image_np = np.arange(9 * 5, dtype=np.uint8).reshape((9, 5)) mask_image_pil = Image.fromarray(np.copy(mask_image_np)) # The target resolution is 3x5 (with center cropping). tf = SDImageTransform( image_field_names=["image", "mask"], fields_to_normalize_to_range_minus_one_to_one=["image"], resolution=(3, 5), center_crop=True, ) out_example = tf({"image": in_image_pil, "mask": mask_image_pil}) # Verify that the correct region of the image was cropped. out_image = out_example["image"] out_image_np = np.array(out_image) assert np.allclose(denormalize_image(out_image_np), in_image_np[3:-3, :, :]) assert out_example["crop_top_left_yx"] == (3, 0) # Verify that the correct region of the mask was cropped. out_mask = out_example["mask"] out_mask_np = np.array(out_mask) assert np.allclose(denormalize_mask(out_mask_np), mask_image_np[3:-3, :]) def test_sd_image_transform_random_crop(): """Test SDImageTransform random cropping.""" # Input image is 9 x 5. in_image_np = np.arange(9 * 5 * 3, dtype=np.uint8).reshape((9, 5, 3)) in_image_pil = Image.fromarray(np.copy(in_image_np)) mask_image_np = np.arange(9 * 5, dtype=np.uint8).reshape((9, 5)) mask_image_pil = Image.fromarray(np.copy(mask_image_np)) # The target resolution is 3x5 (with random cropping). resolution = Resolution(3, 5) tf = SDImageTransform( image_field_names=["image", "mask"], fields_to_normalize_to_range_minus_one_to_one=["image"], resolution=resolution, center_crop=False, ) out_example = tf({"image": in_image_pil, "mask": mask_image_pil}) # Verify that the crop_top_left_yx value is correct. 
out_image = out_example["image"] out_image_np = np.array(out_image) crop_y, crop_x = out_example["crop_top_left_yx"] assert np.allclose( denormalize_image(out_image_np), in_image_np[crop_y : crop_y + resolution.height, crop_x : crop_x + resolution.width, :], ) # Verify that the mask was cropped in the same way as the image. out_mask = out_example["mask"] out_mask_np = np.array(out_mask) assert np.allclose( denormalize_mask(out_mask_np), mask_image_np[crop_y : crop_y + resolution.height, crop_x : crop_x + resolution.width], ) def test_sd_image_transform_center_crop_flip(): """Test SDImageTransform center cropping with a horizontal flip.""" # Input image is 5 x 9. in_image_np = np.arange(5 * 9 * 3, dtype=np.uint8).reshape((5, 9, 3)) in_image_pil = Image.fromarray(np.copy(in_image_np)) in_mask_np = np.arange(5 * 9, dtype=np.uint8).reshape((5, 9)) in_mask_pil = Image.fromarray(np.copy(in_mask_np)) # The target resolution is 5x3 (with center cropping and horizontal flipping). tf = SDImageTransform( image_field_names=["image", "mask"], fields_to_normalize_to_range_minus_one_to_one=["image"], resolution=Resolution(5, 3), center_crop=True, random_flip=True, ) # Note: We patch random.random() to force a horizontal flip to be applied. with unittest.mock.patch("random.random", return_value=0.0): out_example = tf({"image": in_image_pil, "mask": in_mask_pil}) # Verify that the correct region of the image was cropped/flipped. # For this comparison, we flip the in_image_np first, then apply the expected crop. out_image = out_example["image"] out_image_np = np.array(out_image) assert np.allclose(denormalize_image(out_image_np), in_image_np[:, ::-1, :][:, 3:-3, :]) assert out_example["crop_top_left_yx"] == (0, 3) # Verify that the correct region of the mask was cropped/flipped. 
out_mask = out_example["mask"] out_mask_np = np.array(out_mask) assert np.allclose(denormalize_mask(out_mask_np), in_mask_np[:, ::-1][:, 3:-3]) def test_sd_image_transform_random_crop_flip(): """Test SDImageTransform random cropping with a horizontal flip.""" # Input image is 5 x 9. in_image_np = np.arange(5 * 9 * 3, dtype=np.uint8).reshape((5, 9, 3)) in_image_pil = Image.fromarray(np.copy(in_image_np)) in_mask_np = np.arange(5 * 9, dtype=np.uint8).reshape((5, 9)) in_mask_pil = Image.fromarray(np.copy(in_mask_np)) # The target resolution is 5x3 (with random cropping and horizontal flipping). resolution = Resolution(5, 3) tf = SDImageTransform( image_field_names=["image", "mask"], fields_to_normalize_to_range_minus_one_to_one=["image"], resolution=resolution, center_crop=False, random_flip=True, ) # Note: We patch random.random() to force a horizontal flip to be applied. with unittest.mock.patch("random.random", return_value=0.0): out_example = tf({"image": in_image_pil, "mask": in_mask_pil}) # Verify that the crop_top_left_yx value is correct. # For this comparison, we flip the in_image_np first, then apply the expected crop. out_image = out_example["image"] out_image_np = np.array(out_image) crop_y, crop_x = out_example["crop_top_left_yx"] assert np.allclose( denormalize_image(out_image_np), in_image_np[:, ::-1, :][crop_y : crop_y + resolution.height, crop_x : crop_x + resolution.width, :], ) # Verify thath the mask was cropped in the same way as the image. out_mask = out_example["mask"] out_mask_np = np.array(out_mask) assert np.allclose( denormalize_mask(out_mask_np), in_mask_np[:, ::-1][crop_y : crop_y + resolution.height, crop_x : crop_x + resolution.width], ) def test_sd_image_transform_aspect_ratio_bucket_manager(): # Input image is 9 x 5. 
in_image_np = np.arange(9 * 5 * 3, dtype=np.uint8).reshape((9, 5, 3)) in_image_pil = Image.fromarray(np.copy(in_image_np)) in_mask_np = np.arange(9 * 5, dtype=np.uint8).reshape((9, 5)) in_mask_pil = Image.fromarray(np.copy(in_mask_np)) # Initialize SDImageTransform with an AspectRatioBucketManager that has a single 3x5 bucket. aspect_ratio_bucket_manager = AspectRatioBucketManager(buckets={Resolution(3, 5)}) tf = SDImageTransform( image_field_names=["image", "mask"], fields_to_normalize_to_range_minus_one_to_one=["image"], resolution=None, aspect_ratio_bucket_manager=aspect_ratio_bucket_manager, center_crop=True, ) out_example = tf({"image": in_image_pil, "mask": in_mask_pil}) # Verify that the correct region of the image was cropped. out_image = out_example["image"] out_image_np = np.array(out_image) assert np.allclose(denormalize_image(out_image_np), in_image_np[3:-3, :, :]) assert out_example["crop_top_left_yx"] == (3, 0) # Verify that the correct region of the mask was cropped. out_mask = out_example["mask"] out_mask_np = np.array(out_mask) assert np.allclose(denormalize_mask(out_mask_np), in_mask_np[3:-3, :]) @pytest.mark.parametrize( ["resolution", "aspect_ratio_bucket_manager"], [ (Resolution(512, 512), AspectRatioBucketManager({})), (None, None), ], ) def test_sd_image_transform_resolution_input_validation( resolution: Resolution | None, aspect_ratio_bucket_manager: AspectRatioBucketManager | None ): with pytest.raises(ValueError): _ = SDImageTransform( image_field_names=["image", "mask"], fields_to_normalize_to_range_minus_one_to_one=["image"], resolution=resolution, aspect_ratio_bucket_manager=aspect_ratio_bucket_manager, ) ================================================ FILE: tests/invoke_training/_shared/data/transforms/test_shuffle_caption_transform.py ================================================ from invoke_training._shared.data.transforms.shuffle_caption_transform import ShuffleCaptionTransform def test_shuffle_caption_transform(): tf = 
ShuffleCaptionTransform(field_name="test_field", seed=3) in_example = {"test_field": "prompt part 1, prompt part 2"} out_example = tf(in_example) # Note that the expected output depends on the seed. assert out_example == {"test_field": "prompt part 2, prompt part 1"} def test_shuffle_caption_transform_no_delimiter(): tf = ShuffleCaptionTransform(field_name="test_field") in_example = {"test_field": "prompt part 1"} out_example = tf(in_example) assert out_example == {"test_field": "prompt part 1"} ================================================ FILE: tests/invoke_training/_shared/data/transforms/test_template_caption_transform.py ================================================ import pytest from invoke_training._shared.data.transforms.template_caption_transform import ( TemplateCaptionTransform, ) def test_template_caption_transform(): tf = TemplateCaptionTransform( field_name="test_field", placeholder_str="placeholder", caption_templates=["template 1 {}"] ) in_example = {"existing": 2} out_example = tf(in_example) assert out_example == {"existing": 2, "test_field": "template 1 placeholder"} def test_template_caption_transform_seed(): field_name = "test_field" placeholder_str = "placeholder" caption_templates = ["template 1 {}", "template 2 {}"] tf = TemplateCaptionTransform( field_name=field_name, placeholder_str=placeholder_str, caption_templates=caption_templates, seed=123, ) # Run on 10 examples with baseline seed 123. out_examples = [tf({}) for _ in range(10)] # Run on 10 examples with same seed and assert that results match. tf = TemplateCaptionTransform( field_name=field_name, placeholder_str=placeholder_str, caption_templates=caption_templates, seed=123, ) out_examples_same_seed = [tf({}) for _ in range(10)] assert out_examples == out_examples_same_seed # Run on 10 examples with a different seed and assert that the results don't match. 
tf = TemplateCaptionTransform( field_name=field_name, placeholder_str=placeholder_str, caption_templates=caption_templates, seed=456, ) out_examples_diff_seed = [tf({}) for _ in range(10)] assert out_examples != out_examples_diff_seed def test_template_caption_transform_bad_templates(): tf = TemplateCaptionTransform( field_name="test_field", placeholder_str="placeholder", caption_templates=["template 1"] ) in_example = {"existing": 2} with pytest.raises(AssertionError): _ = tf(in_example) ================================================ FILE: tests/invoke_training/_shared/data/transforms/test_tensor_disk_cache.py ================================================ from pathlib import Path import pytest import torch from invoke_training._shared.data.transforms.tensor_disk_cache import TensorDiskCache def test_tensor_disk_cache_roundtrip(tmp_path: Path): """Test a TensorDiskCache cache roundtrip.""" cache = TensorDiskCache(str(tmp_path)) in_dict = {"test_tensor": torch.rand((1, 2, 3)), "test_tuple": (1, 2), "test_list": [3, 4], "test_scalar": 1} # Roundtrip cache.save(0, in_dict) out_dict = cache.load(0) assert set(in_dict.keys()) == set(out_dict.keys()) torch.testing.assert_close(out_dict["test_tensor"], in_dict["test_tensor"]) assert out_dict["test_tuple"] == in_dict["test_tuple"] assert out_dict["test_list"] == in_dict["test_list"] assert out_dict["test_scalar"] == in_dict["test_scalar"] def test_tensor_disk_cache_fail_overwrite(tmp_path): """Test that an attempt to overwrite an existing TensorDiskCache cache entry raises a ValueError.""" cache = TensorDiskCache(str(tmp_path)) in_dict = {"test_tensor": torch.rand((1, 2, 3))} cache.save(0, in_dict) with pytest.raises(AssertionError): cache.save(0, in_dict) ================================================ FILE: tests/invoke_training/_shared/data/utils/__init__.py ================================================ ================================================ FILE: 
tests/invoke_training/_shared/data/utils/test_aspect_ratio_bucket_manager.py ================================================ from contextlib import nullcontext import pytest from invoke_training._shared.data.utils.aspect_ratio_bucket_manager import AspectRatioBucketManager from invoke_training._shared.data.utils.resolution import Resolution @pytest.mark.parametrize( ["target_resolution", "start_dim", "end_dim", "divisible_by", "should_raise"], [ (1024, 512, 2048, 64, False), (1025, 512, 2048, 64, True), # target_resolution not divisible by divisible_by. (1024, 513, 2048, 64, True), # start_dim not divisible by divisible_by. (1024, 512, 2049, 64, True), # end_dim not divisible by divisible_by. (1024, 1024, 512, 64, True), # start_dim > end_dim. ], ) def test_build_aspect_ratio_buckets_input_validation( target_resolution: int, start_dim: int, end_dim: int, divisible_by: int, should_raise: bool ): """Test validation of all input params to AspectRatioBucketManager.build_aspect_ratio_buckets(...).""" expectation = pytest.raises(AssertionError) if should_raise else nullcontext() with expectation: _ = AspectRatioBucketManager.build_aspect_ratio_buckets( target_resolution=target_resolution, start_dim=start_dim, end_dim=end_dim, divisible_by=divisible_by, ) @pytest.mark.parametrize( ["target_resolution", "start_dim", "end_dim", "divisible_by", "expected"], [ # 1 bucket (1024, 1024, 1024, 64, {Resolution(1024, 1024)}), # Multiple buckets. 
( 1024, 768, 1280, 128, { Resolution(768, 1280), Resolution(896, 1152), Resolution(1024, 1024), Resolution(1152, 896), Resolution(1280, 768), }, ), ], ) def test_build_aspect_ratio_buckets( target_resolution: int, start_dim: int, end_dim: int, divisible_by: int, expected: set[Resolution], ): buckets = AspectRatioBucketManager.build_aspect_ratio_buckets( target_resolution=target_resolution, start_dim=start_dim, end_dim=end_dim, divisible_by=divisible_by, ) assert buckets == expected @pytest.mark.parametrize( ["resolution", "expected_bucket"], [ (Resolution(1024, 1024), Resolution(1024, 1024)), # Exact match. (Resolution(128, 1024), Resolution(768, 1280)), # Small aspect ratio. (Resolution(1024, 128), Resolution(1280, 768)), # Large aspect ratio. ], ) def test_get_aspect_ratio_bucket(resolution: Resolution, expected_bucket: Resolution): arbm = AspectRatioBucketManager.from_constraints( target_resolution=1024, start_dim=768, end_dim=1280, divisible_by=128 ) nearest_bucket = arbm.get_aspect_ratio_bucket(resolution) assert nearest_bucket == expected_bucket ================================================ FILE: tests/invoke_training/_shared/data/utils/test_resize.py ================================================ import numpy as np import pytest from PIL import Image from invoke_training._shared.data.utils.resize import resize_to_cover from invoke_training._shared.data.utils.resolution import Resolution @pytest.mark.parametrize( ["in_resolution", "size_to_cover", "expected_resolution"], [ # Perfect match, no resize necessary. (Resolution(512, 768), Resolution(512, 768), Resolution(512, 768)), # Height matches, width covers, no resize necessary. (Resolution(768, 768), Resolution(768, 512), Resolution(768, 768)), # Width matches, height covers, no resize necessary. (Resolution(768, 768), Resolution(512, 768), Resolution(768, 768)), # Height matches, width does not cover, scale up. 
(Resolution(768, 256), Resolution(768, 512), Resolution(1536, 512)), # Width matches, height does not cover, scale up. (Resolution(256, 768), Resolution(512, 768), Resolution(512, 1536)), # Both width and height exceed target, scale down, limited by height. (Resolution(1024, 768), Resolution(768, 512), Resolution(768, 576)), # Both width and height exceed target, scale down, limited by width. (Resolution(768, 1024), Resolution(512, 768), Resolution(576, 768)), ], ) def test_resize_to_cover(in_resolution: Resolution, size_to_cover: Resolution, expected_resolution: Resolution): in_img = np.zeros((in_resolution.height, in_resolution.width, 3), dtype=np.uint8) in_img = Image.fromarray(in_img) out_img = resize_to_cover(in_img, size_to_cover) assert out_img.height == expected_resolution.height assert out_img.width == expected_resolution.width ================================================ FILE: tests/invoke_training/_shared/data/utils/test_resolution.py ================================================ import pytest from invoke_training._shared.data.utils.resolution import Resolution @pytest.mark.parametrize( ["input", "expected_resolution"], [ (5, Resolution(5, 5)), # From int. ((5, 6), Resolution(5, 6)), # From tuple[int, int]. (Resolution(5, 6), Resolution(5, 6)), # From Resolution. 
], ) def test_resolution_parse(input, expected_resolution: Resolution): resolution = Resolution.parse(input) assert resolution == expected_resolution ================================================ FILE: tests/invoke_training/_shared/stable_diffusion/__init__.py ================================================ ================================================ FILE: tests/invoke_training/_shared/stable_diffusion/test_base_model_version.py ================================================ import pytest from transformers import PretrainedConfig from invoke_training._shared.stable_diffusion.base_model_version import ( BaseModelVersionEnum, check_base_model_version, get_base_model_version, ) @pytest.mark.loads_model @pytest.mark.parametrize( ["diffusers_model_name", "expected_version"], [ ("runwayml/stable-diffusion-v1-5", BaseModelVersionEnum.STABLE_DIFFUSION_V1), ("stabilityai/stable-diffusion-2-1", BaseModelVersionEnum.STABLE_DIFFUSION_V2), ("stabilityai/stable-diffusion-xl-base-1.0", BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_BASE), ("stabilityai/stable-diffusion-xl-refiner-1.0", BaseModelVersionEnum.STABLE_DIFFUSION_SDXL_REFINER), ], ) def test_get_base_model_version(diffusers_model_name: str, expected_version: BaseModelVersionEnum): """Test get_base_model_version(...) with one test model for each supported version.""" # Check if the diffusers_model_name model is downloaded and xfail if not. # This check ensures that users don't have to download all of the test models just to run the test suite. try: _ = PretrainedConfig.from_pretrained( pretrained_model_name_or_path=diffusers_model_name, subfolder="unet", local_files_only=True, ) except OSError: pytest.xfail(f"'{diffusers_model_name}' is not downloaded.") version = get_base_model_version(diffusers_model_name) assert version == expected_version @pytest.mark.loads_model def test_check_base_model_version_pass(): """Test that check_base_model_version(...) 
does not raise an Exception when the model is valid.""" check_base_model_version({BaseModelVersionEnum.STABLE_DIFFUSION_V1}, "runwayml/stable-diffusion-v1-5") @pytest.mark.loads_model def test_check_base_model_version_fail(): """Test that check_base_model_version(...) raises a ValueError when the model is invalid.""" with pytest.raises(ValueError): check_base_model_version({BaseModelVersionEnum.STABLE_DIFFUSION_V2}, "runwayml/stable-diffusion-v1-5") ================================================ FILE: tests/invoke_training/_shared/stable_diffusion/test_lora_checkpoint_utils.py ================================================ from pathlib import Path import pytest from invoke_training._shared.stable_diffusion.lora_checkpoint_utils import ( convert_sd_peft_checkpoint_to_kohya_state_dict, ) def test_convert_sd_peft_checkpoint_to_kohya_state_dict_raise_on_empty_directory(tmp_path: Path): with pytest.raises(ValueError, match="No checkpoint files found in directory"): convert_sd_peft_checkpoint_to_kohya_state_dict( in_checkpoint_dir=tmp_path, out_checkpoint_file=tmp_path / "out.safetensors" ) def test_convert_sd_peft_checkpoint_to_kohya_state_dict_raise_on_unexpected_subdirectory(tmp_path: Path): subdirectory = tmp_path / "subdir" subdirectory.mkdir() with pytest.raises(ValueError, match=f"Unrecognized checkpoint directory: '{subdirectory}'."): convert_sd_peft_checkpoint_to_kohya_state_dict( in_checkpoint_dir=tmp_path, out_checkpoint_file=tmp_path / "out.safetensors" ) ================================================ FILE: tests/invoke_training/_shared/stable_diffusion/test_model_loading_utils.py ================================================ import logging from pathlib import Path import pytest import torch from transformers import CLIPTextModel, CLIPTokenizer from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sd, load_models_sdxl from .ti_embedding_checkpoint_fixture import ( # noqa: F401 sdv1_embedding_path, sdxl_embedding_path, ) 
@pytest.mark.loads_model def test_load_models_sd(sdv1_embedding_path): # noqa: F811 model_name = "runwayml/stable-diffusion-v1-5" tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd( logger=logging.getLogger(__name__), model_name_or_path=model_name, hf_variant="fp16", base_embeddings={"special_test_token": str(sdv1_embedding_path)}, ) token_ids = tokenizer.encode("special_test_token special_test_token_1", add_special_tokens=False) assert len(token_ids) == 2 token_embeds = text_encoder.get_input_embeddings().weight.data for token_id in token_ids: # The embedding should be all zeros, because that is how it was initialized in the sdv1_embedding_path # fixture. assert torch.allclose(token_embeds[token_id], torch.zeros_like(token_embeds[token_id])) @pytest.mark.loads_model def test_load_models_sdxl(sdxl_embedding_path: Path): # noqa: F811 model_name = "stabilityai/stable-diffusion-xl-base-1.0" tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl( logger=logging.getLogger(__name__), model_name_or_path=model_name, hf_variant="fp16", base_embeddings={"special_test_token": str(sdxl_embedding_path)}, ) # Validate that the embeddings were applied correctly. def validate_ti_embeddings(tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel): token_ids = tokenizer.encode("special_test_token special_test_token_1", add_special_tokens=False) assert len(token_ids) == 2 token_embeds = text_encoder.get_input_embeddings().weight.data for token_id in token_ids: # The embedding should be all zeros, because that is how it was initialized in the sdxl_embedding_path # fixture. 
assert torch.allclose(token_embeds[token_id], torch.zeros_like(token_embeds[token_id])) validate_ti_embeddings(tokenizer_1, text_encoder_1) validate_ti_embeddings(tokenizer_2, text_encoder_2) ================================================ FILE: tests/invoke_training/_shared/stable_diffusion/test_textual_inversion.py ================================================ import logging from pathlib import Path import pytest import torch from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sd from invoke_training._shared.stable_diffusion.textual_inversion import ( _expand_placeholder_token, initialize_placeholder_tokens_from_initial_embedding, initialize_placeholder_tokens_from_initial_phrase, initialize_placeholder_tokens_from_initializer_token, ) from .ti_embedding_checkpoint_fixture import sdv1_embedding_path # noqa: F401 @pytest.mark.parametrize( ["placeholder_token", "num_vectors", "expected_placeholder_tokens"], [("abc", 1, ["abc"]), ("abc", 2, ["abc", "abc_1"]), ("abc", 3, ["abc", "abc_1", "abc_2"])], ) def test_expand_placeholder_token(placeholder_token: str, num_vectors: int, expected_placeholder_tokens: list[str]): assert _expand_placeholder_token(placeholder_token, num_vectors) == expected_placeholder_tokens def test_expand_placeholder_token_raises_on_invalid_num_vectors(): with pytest.raises(ValueError): _expand_placeholder_token("abc", 0) @pytest.mark.loads_model def test_initialize_placeholder_tokens_from_initializer_token(): tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd( logger=logging.getLogger(__name__), model_name_or_path="runwayml/stable-diffusion-v1-5", hf_variant="fp16" ) initializer_token = "dog" num_vectors = 2 placeholder_tokens, placeholder_token_ids = initialize_placeholder_tokens_from_initializer_token( tokenizer=tokenizer, text_encoder=text_encoder, initializer_token=initializer_token, placeholder_token="dog_placeholder", num_vectors=num_vectors, logger=logging.getLogger(), ) assert 
len(placeholder_tokens) == num_vectors assert len(placeholder_token_ids) == num_vectors assert placeholder_tokens == ["dog_placeholder", "dog_placeholder_1"] token_embeds = text_encoder.get_input_embeddings().weight.data initializer_token_id = tokenizer.encode(initializer_token, add_special_tokens=False)[0] with torch.no_grad(): for placeholder_token_id in placeholder_token_ids: assert torch.allclose(token_embeds[placeholder_token_id], token_embeds[initializer_token_id]) @pytest.mark.loads_model def test_initialize_placeholder_tokens_from_initial_phrase(): tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd( logger=logging.getLogger(__name__), model_name_or_path="runwayml/stable-diffusion-v1-5", hf_variant="fp16" ) initial_phrase = "little brown dog" placeholder_tokens, placeholder_token_ids = initialize_placeholder_tokens_from_initial_phrase( tokenizer=tokenizer, text_encoder=text_encoder, initial_phrase=initial_phrase, placeholder_token="dog_placeholder", ) expected_num_vectors = 3 assert len(placeholder_tokens) == expected_num_vectors assert len(placeholder_token_ids) == expected_num_vectors assert placeholder_tokens == ["dog_placeholder", "dog_placeholder_1", "dog_placeholder_2"] token_embeds = text_encoder.get_input_embeddings().weight.data initial_token_ids = tokenizer.encode(initial_phrase, add_special_tokens=False) assert len(initial_token_ids) == expected_num_vectors with torch.no_grad(): for placeholder_token_id, initial_token_id in zip(placeholder_token_ids, initial_token_ids): assert torch.allclose(token_embeds[placeholder_token_id], token_embeds[initial_token_id]) @pytest.mark.loads_model def test_initialize_placeholder_tokens_from_initial_embedding(sdv1_embedding_path: Path): # noqa: F811 tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd( logger=logging.getLogger(__name__), model_name_or_path="runwayml/stable-diffusion-v1-5", hf_variant="fp16" ) placeholder_token = "custom_token" num_vectors = 2 placeholder_tokens, 
placeholder_token_ids = initialize_placeholder_tokens_from_initial_embedding( tokenizer=tokenizer, text_encoder=text_encoder, initial_embedding_file=str(sdv1_embedding_path), placeholder_token=placeholder_token, num_vectors=num_vectors, ) assert len(placeholder_tokens) == num_vectors assert len(placeholder_token_ids) == num_vectors assert placeholder_tokens == ["custom_token", "custom_token_1"] token_embeds = text_encoder.get_input_embeddings().weight.data with torch.no_grad(): for placeholder_token_id in placeholder_token_ids: # The placeholder embeddings should be initialized to zero, because this is how they are initialized in the # dummy sdv1_embedding_path checkpoint. assert torch.allclose( token_embeds[placeholder_token_id], torch.zeros_like(token_embeds[placeholder_token_id]) ) ================================================ FILE: tests/invoke_training/_shared/stable_diffusion/ti_embedding_checkpoint_fixture.py ================================================ import pytest import torch from invoke_training._shared.checkpoints.serialization import save_state_dict @pytest.fixture(scope="session") def sdv1_embedding_path(tmp_path_factory: pytest.TempPathFactory): """A fixture that writes a dummy SD v1 TI embedding to a temp dir and returns the embedding path. Note that the 'session' scope is used to share the same directory across all tests in a session. Refer to https://docs.pytest.org/en/7.4.x/how-to/tmp_path.html#the-tmp-path-factory-fixture for details on the use of tmp_path_factory. """ tmp_dir = tmp_path_factory.mktemp("embeddings") embedding_state_dict = {"custom_token": torch.zeros((2, 768))} embedding_path = tmp_dir / "embedding.safetensors" save_state_dict(embedding_state_dict, embedding_path) return embedding_path @pytest.fixture(scope="session") def sdxl_embedding_path(tmp_path_factory: pytest.TempPathFactory): """A fixture that writes a dummy SDXL TI embedding to a temp dir and returns the embedding path. 
Note that the 'session' scope is used to share the same directory across all tests in a session. Refer to https://docs.pytest.org/en/7.4.x/how-to/tmp_path.html#the-tmp-path-factory-fixture for details on the use of tmp_path_factory. """ tmp_dir = tmp_path_factory.mktemp("embeddings") embedding_state_dict = { "clip_l": torch.zeros((2, 768)), "clip_g": torch.zeros((2, 1280)), } embedding_path = tmp_dir / "embedding.safetensors" save_state_dict(embedding_state_dict, embedding_path) return embedding_path ================================================ FILE: tests/invoke_training/_shared/utils/test_jsonl.py ================================================ from pathlib import Path from invoke_training._shared.utils.jsonl import load_jsonl, save_jsonl def test_jsonl_roundtrip(tmp_path: Path): in_objs = [{"a": 1, "b": 2}, {"a": 1, "b": 2}] jsonl_path = tmp_path / "test.jsonl" save_jsonl(in_objs, jsonl_path) out_objs = load_jsonl(jsonl_path) assert in_objs == out_objs ================================================ FILE: tests/invoke_training/config/pipelines/test_pipeline_config.py ================================================ import glob from pathlib import Path import yaml from pydantic import TypeAdapter from invoke_training.config.pipeline_config import PipelineConfig def test_pipeline_config(): """Test that all sample pipeline configs can be parsed as PipelineConfigs.""" cur_file = Path(__file__) config_dir = cur_file.parent.parent.parent.parent.parent / "src/invoke_training/sample_configs" config_files = glob.glob(str(config_dir) + "/**/*.yaml", recursive=True) assert len(config_files) > 0 for config_file in config_files: with open(config_file, "r") as f: cfg = yaml.safe_load(f) pipeline_adapter: TypeAdapter[PipelineConfig] = TypeAdapter(PipelineConfig) try: _ = pipeline_adapter.validate_python(cfg) except Exception as e: raise Exception(f"Error parsing config file: {config_file}") from e ================================================ FILE: 
tests/invoke_training/model_merge/__init__.py ================================================ ================================================ FILE: tests/invoke_training/model_merge/test_merge_models.py ================================================ import math from typing import Literal import pytest import torch from invoke_training.model_merge.merge_models import merge_models from .utils import state_dicts_are_close def test_merge_models_raises_on_not_enough_state_dicts(): with pytest.raises(ValueError, match="Must provide >=2 models to merge."): _ = merge_models(state_dicts=[{}], weights=[0.5], merge_method="LERP") def test_merge_models_raises_on_mismatched_weights(): with pytest.raises(ValueError, match="Must provide a weight for each model."): _ = merge_models(state_dicts=[{}, {}], weights=[0.5, 0.5, 0.5], merge_method="LERP") @pytest.mark.parametrize( ["state_dicts", "weights", "merge_method", "expected_state_dict"], [ # Lerp. ( [ {"a": torch.tensor(1.0), "b": torch.tensor(2.0)}, {"a": torch.tensor(3.0), "b": torch.tensor(4.0)}, ], [1.0, 1.0], "LERP", {"a": torch.tensor(2.0), "b": torch.tensor(3.0)}, ), # Lerp with unbalanced weights. ( [ {"a": torch.tensor(1.0), "b": torch.tensor(2.0)}, {"a": torch.tensor(3.0), "b": torch.tensor(4.0)}, ], [1.0, 3.0], "LERP", {"a": torch.tensor(1.0 * 0.25 + 3.0 * 0.75), "b": torch.tensor(2.0 * 0.25 + 4.0 * 0.75)}, ), # Lerp with more than 2 state dicts. ( [ {"a": torch.tensor(1.0), "b": torch.tensor(2.0)}, {"a": torch.tensor(2.0), "b": torch.tensor(3.0)}, {"a": torch.tensor(3.0), "b": torch.tensor(4.0)}, ], [1.0, 1.0, 1.0], "LERP", {"a": torch.tensor(2.0), "b": torch.tensor(3.0)}, ), # Slerp with scalar tensors falls back to lerp. ( [ {"a": torch.tensor(1.0), "b": torch.tensor(2.0)}, {"a": torch.tensor(3.0), "b": torch.tensor(4.0)}, ], [1.0, 1.0], "SLERP", {"a": torch.tensor(2.0), "b": torch.tensor(3.0)}, ), # Slerp with colinear vector tensors falls back to lerp. 
( [ {"a": torch.tensor([1.0, 2.0])}, {"a": torch.tensor([2.0, 4.0])}, ], [1.0, 1.0], "SLERP", {"a": torch.tensor([1.5, 3.0])}, ), # Slerp with orthogonal vector tensors. ( [ {"a": torch.tensor([1.0, 0.0])}, {"a": torch.tensor([0.0, 1.0])}, ], [1.0, 1.0], "SLERP", {"a": torch.tensor([math.sin(math.pi / 4), math.sin(math.pi / 4)])}, ), ], ) def test_merge_models( state_dicts: list[dict[str, torch.Tensor]], weights: list[float], merge_method: Literal["LERP", "SLERP"], expected_state_dict: dict[str, torch.Tensor], ): merged_state_dict = merge_models(state_dicts=state_dicts, weights=weights, merge_method=merge_method) assert state_dicts_are_close(merged_state_dict, expected_state_dict) ================================================ FILE: tests/invoke_training/model_merge/test_merge_tasks_to_base.py ================================================ from typing import Literal import pytest import torch from invoke_training.model_merge.merge_tasks_to_base import merge_tasks_to_base_model from .utils import state_dicts_are_close def test_merge_raises_on_mismatched_weights(): with pytest.raises(ValueError, match="Must provide a weight for each model."): _ = merge_tasks_to_base_model({}, [{}, {}], [0.5, 0.5, 0.5]) @pytest.mark.parametrize( ["base_state_dict", "task_state_dicts", "task_weights", "density", "merge_method", "expected_state_dict"], [ # TIES. ( {"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0, 4.0])}, [ {"a": torch.tensor([2.0, 7.0]), "b": torch.tensor([3.0, 6.0])}, {"a": torch.tensor([7.0, 3.0]), "b": torch.tensor([3.0, 7.0])}, ], [1.0, 1.0], 0.5, "TIES", # Expected task diff state dict: # {"a": torch.tensor([1.0, 5.0]), "b": torch.tensor([0.0, 2.0])}, # {"a": torch.tensor([6.0, 1.0]), "b": torch.tensor([0.0, 3.0])}, # Expected merged diff state dict: # {"a": torch.tensor([6.0, 5.0]), "b": torch.tensor([0.0, 2.5])}, # Expected final result: {"a": torch.tensor([7.0, 7.0]), "b": torch.tensor([3.0, 6.5])}, ), # DARE_LINEAR. 
( {"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0, 4.0])}, [ {"a": torch.tensor([2.0, 7.0]), "b": torch.tensor([3.0, 6.0])}, {"a": torch.tensor([7.0, 3.0]), "b": torch.tensor([3.0, 7.0])}, ], [1.0, 1.0], # Set density to 1.0 so that we can set an expected result without having to handle seeding the RNG. 1.0, "DARE_LINEAR", {"a": torch.tensor([8.0, 8.0]), "b": torch.tensor([3.0, 9.0])}, ), # DARE_TIES. ( {"a": torch.tensor([1.0, 2.0]), "b": torch.tensor([3.0, 4.0])}, [ {"a": torch.tensor([2.0, 7.0]), "b": torch.tensor([3.0, 6.0])}, {"a": torch.tensor([7.0, 3.0]), "b": torch.tensor([3.0, 7.0])}, ], [1.0, 1.0], # Set density to 1.0 so that we can set an expected result without having to handle seeding the RNG. 1.0, "DARE_TIES", {"a": torch.tensor([4.5, 5.0]), "b": torch.tensor([3.0, 6.5])}, ), ], ) def test_merge_ties( base_state_dict: dict[str, torch.Tensor], task_state_dicts: list[dict[str, torch.Tensor]], task_weights: list[float], density: float, merge_method: Literal["TIES", "DARE_LINEAR", "DARE_TIES"], expected_state_dict: dict[str, torch.Tensor], ): merged_state_dict = merge_tasks_to_base_model( base_state_dict=base_state_dict, task_state_dicts=task_state_dicts, task_weights=task_weights, density=density, merge_method=merge_method, ) assert state_dicts_are_close(merged_state_dict, expected_state_dict) ================================================ FILE: tests/invoke_training/model_merge/utils.py ================================================ import torch def state_dicts_are_close(a: dict[str, torch.Tensor], b: dict[str, torch.Tensor]) -> bool: """Helper function for comparing two state dicts.""" return all(torch.allclose(a[key], b[key]) for key in a.keys()) ================================================ FILE: tests/invoke_training/ui/utils/test_prompts.py ================================================ import pytest from invoke_training.ui.utils.prompts import ( convert_pos_neg_prompts_to_ui_prompts, convert_ui_prompts_to_pos_neg_prompts, 
split_pos_neg_prompts, ) @pytest.mark.parametrize( ["prompt", "expected_positive_prompt", "expected_negative_prompt"], [ # Simple positive and negative prompt. ("positive prompt[NEG]negative prompt", "positive prompt", "negative prompt"), # Positive prompt with no negative prompt. ("positive prompt", "positive prompt", ""), # Empty prompt. ("", "", ""), ], ) def test_split_pos_neg_prompts(prompt: str, expected_positive_prompt: str, expected_negative_prompt: str): positive_prompt, negative_prompt = split_pos_neg_prompts(prompt) assert positive_prompt == expected_positive_prompt assert negative_prompt == expected_negative_prompt @pytest.mark.parametrize( "prompt", [ # Multiple negative prompt delimiters. "positive prompt[NEG]negative prompt[NEG]negative prompt", ], ) def test_split_pos_neg_prompts_raises_value_error(prompt: str): with pytest.raises(ValueError): split_pos_neg_prompts(prompt) # Test cases for conversion between UI prompts and positive/negative prompts. # Each test case consists of: (ui_prompts, positive_prompts, negative_prompts) prompt_conversion_test_cases = [ # Positive prompts. ( "positive prompt 1\npositive prompt 2\npositive prompt 3", ["positive prompt 1", "positive prompt 2", "positive prompt 3"], None, ), # Positive prompts with trailing \n. ( "positive prompt 1\npositive prompt 2\npositive prompt 3\n", ["positive prompt 1", "positive prompt 2", "positive prompt 3"], None, ), # Positive and negative prompts. ( "positive prompt 1[NEG]negative prompt 1\npositive prompt 2[NEG]negative prompt 2\n" "positive prompt 3[NEG]negative prompt 3\n", ["positive prompt 1", "positive prompt 2", "positive prompt 3"], ["negative prompt 1", "negative prompt 2", "negative prompt 3"], ), # Some missing negative prompts. 
( "positive prompt 1[NEG]negative prompt 1\npositive prompt 2\npositive prompt 3[NEG]negative prompt 3\n", ["positive prompt 1", "positive prompt 2", "positive prompt 3"], ["negative prompt 1", "", "negative prompt 3"], ), ] @pytest.mark.parametrize( ["ui_prompts", "expected_positive_prompts", "expected_negative_prompts"], prompt_conversion_test_cases ) def test_convert_ui_prompts_to_pos_neg_prompts( ui_prompts: str, expected_positive_prompts: list[str], expected_negative_prompts: list[str | None] | None ): positive_prompts, negative_prompts = convert_ui_prompts_to_pos_neg_prompts(ui_prompts) assert positive_prompts == expected_positive_prompts assert negative_prompts == expected_negative_prompts @pytest.mark.parametrize(["expected_ui_prompts", "positive_prompts", "negative_prompts"], prompt_conversion_test_cases) def test_convert_pos_neg_prompts_to_ui_prompts( expected_ui_prompts: str, positive_prompts: list[str], negative_prompts: list[str | None] | None ): ui_prompts = convert_pos_neg_prompts_to_ui_prompts(positive_prompts, negative_prompts) assert ui_prompts == expected_ui_prompts.strip()